Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(2MN\) elements (X is read once and Y written once), so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal.
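To make the traffic estimate concrete, here is a quick sanity check of the arithmetic, using an arbitrary example shape:

# Hypothetical shape, purely to illustrate the estimate above.
M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes = 8MN + 4M
fused_traffic = 2 * M * N                                  # read X once, write Y once
print(naive_traffic / fused_traffic)  # ~4.002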

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.
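Written out in plain Python, this row assignment is a simple grid-stride loop; in the sketch below, pid and num_programs stand in for tl.program_id(0) and tl.num_programs(0) inside the kernel:

# Sketch of the per-program row assignment (not part of the kernel itself).
def rows_for_program(pid, num_programs, n_rows):
    return list(range(pid, n_rows, num_programs))

rows_for_program(1, 4, 10)  # -> [1, 5, 9]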

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
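The choice of other=-float('inf') in the masked load is what makes the padding harmless: -inf never wins the max, and exp(-inf) = 0 contributes nothing to the denominator. A quick PyTorch check, with arbitrary sizes, illustrates the idea:

# Padding a row with -inf leaves its softmax unchanged.
row = torch.randn(781, device=DEVICE)
padded = torch.full((1024, ), -float('inf'), device=DEVICE)
padded[:781] = row
assert torch.allclose(torch.softmax(row, dim=0), torch.softmax(padded, dim=0)[:781])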

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular-purpose registers. On CDNA architectures this is
        # half of all registers available; however, the split is not fixed, and in most cases all
        # registers can be used as regular-purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # Dividing it by WARP_SIZE gives the maximum number of waves that can execute
        # on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
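To get a feel for the occupancy computation, here is the CUDA branch evaluated with made-up but plausible numbers (the actual values depend on your device and on the compiled kernel):

# Hypothetical device/kernel numbers, for illustration only.
NUM_REGS_EX, WARP_SIZE_EX, NUM_SM_EX = 65536, 32, 108
n_regs_ex, num_warps_ex = 32, 8             # registers per thread, warps per program
size_smem_ex, SIZE_SMEM_EX = 8192, 166912   # shared memory: per program / per SM

occupancy_ex = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)  # 8 programs per SM fit by registers
occupancy_ex = min(occupancy_ex, SIZE_SMEM_EX // size_smem_ex)           # 20 fit by shared memory -> still 8
print(NUM_SM_EX * occupancy_ex)  # 864 persistent programs (before the min with n_rows)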

Unit Test

We test our kernel on a matrix with an irregular number of rows and columns, which lets us verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match within floating-point tolerance.
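Since exponentiation via tl.exp is fast but approximate, it can be worth making the comparison tolerance explicit. A variant of the check above, with arbitrarily chosen tolerances, might look like:

torch.testing.assert_close(y_triton, y_torch, rtol=1e-5, atol=1e-6)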

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    # Effective bandwidth: count one read and one write of the tensor,
    # i.e. 2 * numel * element_size bytes, regardless of provider.
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Plot: softmax-performance — GB/s vs. N for Triton, Torch, and Naive Softmax]
softmax-performance:
          N       Triton        Torch  Naive Softmax
0     256.0   475.166542   675.218075     206.112331
1     384.0   655.123930   829.318041     261.964284
2     512.0   820.240936   920.826918     300.119709
3     640.0   919.818146   910.222212     329.405789
4     768.0   988.500953   980.283489     348.762065
5     896.0  1046.180499  1034.781675     353.955498
6    1024.0  1085.341715  1082.789026     352.988781
7    1152.0  1088.758877  1077.305928     347.923159
8    1280.0  1125.320758  1109.671632     349.082970
9    1408.0  1156.547509  1140.081001     340.524440
10   1536.0  1195.061126  1159.202653     333.435594
11   1664.0  1214.181932  1183.183909     329.563218
12   1792.0  1225.788360  1191.537175     325.590999
13   1920.0  1260.929975  1224.866509     324.211874
14   2048.0  1269.980596  1243.374771     324.300319
15   2176.0  1239.481839   960.011808     325.278138
16   2304.0  1256.071190  1004.069000     325.824058
17   2432.0  1271.804802  1034.925234     326.790521
18   2560.0  1285.339836  1067.259125     327.731295
19   2688.0  1294.067511  1100.757789     328.664549
20   2816.0  1306.664337  1124.630729     329.752174
21   2944.0  1321.303487  1147.969438     331.210784
22   3072.0  1322.431384  1175.069674     332.841686
23   3200.0  1334.604891  1169.863007     334.714863
24   3328.0  1347.131181  1199.093927     336.254812
25   3456.0  1356.477369  1226.493753     336.719658
26   3584.0  1358.626995  1247.479158     337.674973
27   3712.0  1371.647801  1263.625509     340.304252
28   3840.0  1369.816780  1280.972247     340.121546
29   3968.0  1375.528414  1300.882402     340.857667
30   4096.0  1386.007011  1319.005245     338.494617
31   4224.0  1328.530186  1279.987447     343.407694
32   4352.0  1342.233320  1298.143303     345.206784
33   4480.0  1347.923755  1320.795082     345.862487
34   4608.0  1359.617221  1336.947806     346.801983
35   4736.0  1353.313026  1348.619056     347.983447
36   4864.0  1368.299874  1359.355064     348.815461
37   4992.0  1371.427217  1374.804306     349.770660
38   5120.0  1374.992917  1386.795634     350.996853
39   5248.0  1374.226440  1355.796792     351.547410
40   5376.0  1378.554977  1369.527929     351.521772
41   5504.0  1376.802714  1387.294885     353.572592
42   5632.0  1392.038419  1395.118961     353.055168
43   5760.0  1394.849569  1407.915898     354.943088
44   5888.0  1393.439800  1406.941706     354.985594
45   6016.0  1399.191286  1424.789814     356.708391
46   6144.0  1406.135293  1434.852743     357.084182
47   6272.0  1407.924231  1389.846673     357.488702
48   6400.0  1410.331764  1415.362757     358.710882
49   6528.0  1413.901202  1417.482148     358.771423
50   6656.0  1417.531940  1430.927169     359.590513
51   6784.0  1418.330524  1431.930771     360.224651
52   6912.0  1423.713121  1445.207062     360.969399
53   7040.0  1419.455565  1449.944026     360.943004
54   7168.0  1420.082358  1466.447063     361.613843
55   7296.0  1427.423121  1086.122102     362.204935
56   7424.0  1429.852124  1101.918526     362.868641
57   7552.0  1428.321119  1112.382291     363.327109
58   7680.0  1432.570851  1124.138155     363.699493
59   7808.0  1430.566466  1132.011378     364.685028
60   7936.0  1431.831105  1146.323226     364.691308
61   8064.0  1431.492392  1149.162852     364.668855
62   8192.0  1434.785207  1155.191268     363.917400
63   8320.0  1383.931619  1117.103520     362.098143
64   8448.0  1386.548368  1127.731977     362.750969
65   8576.0  1385.551962  1126.687030     363.575272
66   8704.0  1382.688836  1134.045907     364.324784
67   8832.0  1390.409971  1134.128897     365.182896
68   8960.0  1384.893552  1138.200061     365.852057
69   9088.0  1399.933157  1135.963177     367.165173
70   9216.0  1401.714158  1144.239871     367.645039
71   9344.0  1390.608061  1419.232831     367.894427
72   9472.0  1398.644747  1429.941794     368.439350
73   9600.0  1404.299488  1428.562451     369.194911
74   9728.0  1402.438849  1441.918501     370.136485
75   9856.0  1401.180528  1443.792075     370.079887
76   9984.0  1391.217459  1452.395937     370.604705
77  10112.0  1403.336103  1451.577118     370.635158
78  10240.0  1408.840210  1463.563501     371.169325
79  10368.0  1414.884589  1462.547608     369.757665
80  10496.0  1409.902423  1468.434117     370.484073
81  10624.0  1401.148458  1466.643644     370.384282
82  10752.0  1396.693696  1467.963231     370.900729
83  10880.0  1395.752229  1477.250378     371.346187
84  11008.0  1417.250086  1477.100080     372.010157
85  11136.0  1419.920844  1481.750833     372.824045
86  11264.0  1410.755379  1486.420145     372.460934
87  11392.0  1424.095915  1487.538180     373.904506
88  11520.0  1416.377672  1496.383799     373.583461
89  11648.0  1422.967528  1500.403335     374.609043
90  11776.0  1430.630993  1502.691116     374.757456
91  11904.0  1432.965945  1510.090107     375.334162
92  12032.0  1414.554312  1508.802320     375.739168
93  12160.0  1414.560959  1513.937107     375.532153
94  12288.0  1424.370203  1417.793602     375.967891
95  12416.0  1437.123832  1396.183056     374.712168
96  12544.0  1444.976847  1395.603119     375.687499
97  12672.0  1439.138774  1392.988592     375.176450
In the above plot and table, we can see that:
  • Triton is ~4x faster than the naive PyTorch implementation (naive_softmax). This matches the theoretical speed-up from our memory-traffic analysis and confirms that eager PyTorch performs no fusion here.

  • Triton is competitive with torch.softmax, and noticeably faster for many shapes – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape (see the contiguity note below).
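One practical caveat: the softmax wrapper above adds col_offsets directly to each row pointer, so it assumes the rows of x are contiguous in memory. For inputs whose rows are not contiguous (e.g., a transposed tensor), a simple workaround is to materialize a contiguous copy first:

x_t = torch.randn(781, 1823, device=DEVICE).T  # shape (1823, 781), rows are not contiguous
y_t = softmax(x_t.contiguous())                # make each row contiguous before launching the kernel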
