Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(2MN\) elements in total, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically but, in practice, it is still far from ideal.
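
As a quick sanity check on this arithmetic, the short snippet below plugs a concrete (arbitrarily chosen) shape into the element counts tallied in the comments of naive_softmax and prints the resulting theoretical speed-up; it is illustrative only and not part of the kernel we are about to write.

# Back-of-the-envelope check of the memory-traffic argument above.
# M and N are arbitrary example values; the formulas come from the comments in naive_softmax.
M, N = 4096, 2048
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_traffic = 2 * M * N                                  # fused kernel: read X once, write Y once
print(f"theoretical speed-up: ~{naive_traffic / fused_traffic:.2f}x")  # ~4.00x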

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each of them, and writes the results back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
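
Before moving on, it may help to see what the mask and the -inf padding accomplish. The following is a small pure-PyTorch sketch, not part of the Triton kernel, that mimics one loop iteration for a single row: the padded lanes are filled with -inf, so they contribute exp(-inf) = 0 to the sum and leave the result unchanged.

# Illustrative only: emulate one iteration of softmax_kernel for a single row in plain PyTorch.
row = torch.randn(781, device=DEVICE)
BLOCK_SIZE = triton.next_power_of_2(row.numel())   # 1024
padded = torch.full((BLOCK_SIZE,), -float('inf'), device=DEVICE)
padded[:row.numel()] = row                         # the masked tl.load with other=-inf
z = padded - padded.max()                          # padded lanes stay at -inf
num = torch.exp(z)                                 # exp(-inf) == 0, so the padding is harmless
out = (num / num.sum())[:row.numel()]              # keep only the real columns ("masked store")
assert torch.allclose(out, torch.softmax(row, dim=0))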

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # Dividing this number by WARP_SIZE gives the maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
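
To make the occupancy computation above more concrete, here is a tiny worked example for the CUDA branch using made-up (but plausible) per-SM limits and kernel statistics; the real values of n_regs and size_smem come from the compiled kernel and the device, so treat these numbers as illustrative only.

# Hypothetical values, for illustration only (the _ex suffix avoids clobbering the real globals).
NUM_REGS_ex, WARP_SIZE_ex, SIZE_SMEM_ex = 65536, 32, 232448
n_regs_ex, size_smem_ex, num_warps_ex = 32, 16384, 8
occ_ex = NUM_REGS_ex // (n_regs_ex * WARP_SIZE_ex * num_warps_ex)  # register limit: 8 kernel instances per SM
occ_ex = min(occ_ex, SIZE_SMEM_ex // size_smem_ex)                 # shared-memory limit: 14, so still 8
print(occ_ex)  # the launch would then use 8 * NUM_SM persistent programs (capped by n_rows)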

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results are identical.
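
As an extra, optional sanity check that is not part of the original test, we can also compare the Triton output against the naive_softmax reference defined earlier, which should agree within the default allclose tolerances:

# Optional: the fused kernel should also match the unfused PyTorch reference.
assert torch.allclose(y_triton, naive_softmax(x))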

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against PyTorch’s native torch.softmax (the naive_softmax above served to motivate the analysis and is not included in the plot).

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_names`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
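    # Effective bandwidth: each call reads and writes the M*N tensor once, hence the factor of 2.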
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[softmax-performance plot: effective bandwidth (GB/s) vs. N for Triton and Torch; data below]
softmax-performance:
          N       Triton        Torch
0     256.0   468.525462   706.593640
1     384.0   612.033628   820.719852
2     512.0   762.467266   918.101833
3     640.0   811.147301   961.043509
4     768.0   886.393504  1020.536303
5     896.0   952.535379  1075.271109
6    1024.0  1009.066185  1115.058111
7    1152.0  1099.714890   614.148310
8    1280.0  1152.360083   667.119022
9    1408.0  1157.898410   726.046832
10   1536.0  1186.686291   779.237050
11   1664.0  1221.174560   810.666662
12   1792.0  1233.178123   860.123075
13   1920.0  1248.900661   910.654618
14   2048.0  1278.688234   960.565219
15   2176.0  1241.460066   974.425788
16   2304.0  1249.326045  1012.433184
17   2432.0  1282.541506  1058.271199
18   2560.0  1282.890237  1089.501329
19   2688.0  1297.276822  1101.514696
20   2816.0  1305.926868  1130.480108
21   2944.0  1314.874736  1169.750034
22   3072.0  1335.013006  1184.625721
23   3200.0  1333.402456  1192.356121
24   3328.0  1339.395946  1228.571170
25   3456.0  1354.827941  1246.445136
26   3584.0  1356.554660  1259.200679
27   3712.0  1367.376117  1274.364256
28   3840.0  1370.577873  1298.164587
29   3968.0  1377.633261  1314.526756
30   4096.0  1382.684029  1328.490755
31   4224.0  1340.454745  1160.213260
32   4352.0  1332.913820  1176.221686
33   4480.0  1351.998559  1185.319171
34   4608.0  1362.085333  1197.152115
35   4736.0  1364.513792  1201.355460
36   4864.0  1376.758264  1223.927407
37   4992.0  1368.632676  1235.950819
38   5120.0  1374.251484  1252.467141
39   5248.0  1375.614748  1260.808643
40   5376.0  1379.294446  1286.907848
41   5504.0  1380.427699  1300.938687
42   5632.0  1386.688725  1316.266628
43   5760.0  1395.847816  1323.983693
44   5888.0  1388.921393  1342.447678
45   6016.0  1399.401344  1354.084438
46   6144.0  1409.490053  1376.076502
47   6272.0  1416.754813  1373.426830
48   6400.0  1418.377589  1389.615622
49   6528.0  1414.806780  1397.413756
50   6656.0  1422.453318  1403.039382
51   6784.0  1412.176698  1416.345928
52   6912.0  1428.399821  1426.025129
53   7040.0  1424.151392  1432.381591
54   7168.0  1427.671537  1436.063840
55   7296.0  1433.001417  1444.299869
56   7424.0  1430.516974  1444.240501
57   7552.0  1430.051127  1453.501460
58   7680.0  1435.939215  1462.386600
59   7808.0  1436.211969  1464.431365
60   7936.0  1436.482017  1465.563204
61   8064.0  1443.378996  1472.689937
62   8192.0  1438.849835  1483.776844
63   8320.0  1383.442444  1400.922030
64   8448.0  1372.289864  1402.527918
65   8576.0  1387.675346  1397.870853
66   8704.0  1384.611369  1402.143431
67   8832.0  1378.601729  1404.727793
68   8960.0  1391.895847  1413.714786
69   9088.0  1402.579864  1417.100907
70   9216.0  1396.239680  1425.894734
71   9344.0  1395.964542  1421.808900
72   9472.0  1396.542217  1434.766971
73   9600.0  1391.339223  1432.800387
74   9728.0  1400.931261  1441.334590
75   9856.0  1407.491223  1445.324325
76   9984.0  1396.718897  1447.569104
77  10112.0  1410.516498  1458.217184
78  10240.0  1411.098022  1466.663896
79  10368.0  1409.218281  1466.624640
80  10496.0  1411.795365  1468.286548
81  10624.0  1410.777555  1466.253615
82  10752.0  1400.341882  1474.331605
83  10880.0  1397.879109  1480.252035
84  11008.0  1417.128770  1481.411705
85  11136.0  1417.145322  1483.387262
86  11264.0  1423.102333  1488.682387
87  11392.0  1410.360430  1490.513388
88  11520.0  1418.324676  1494.304419
89  11648.0  1419.401204  1501.700954
90  11776.0  1425.657314  1500.849330
91  11904.0  1436.706558  1508.994477
92  12032.0  1417.533839  1509.353186
93  12160.0  1419.888050  1513.509235
94  12288.0  1427.373825  1392.649035
95  12416.0  1448.379662  1389.611285
96  12544.0  1438.620630  1395.112607
97  12672.0  1444.140489  1391.291386
In the above plot, we can see that:
  • Triton is competitive with torch.softmax, and noticeably faster for a wide range of intermediate column counts, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.

  • Both implementations sustain well over 1 TB/s of effective bandwidth for most problem sizes, which is what we would hope for from a fused, bandwidth-bound kernel; the unfused naive_softmax above, which moves roughly 4x as much data per call, would fall well short of this.

Total running time of the script: (0 minutes 23.151 seconds)
