Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that reads X only once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.
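
To make that arithmetic concrete, here is a small sketch that tallies the element traffic of both versions; the helper name naive_vs_fused_traffic is purely illustrative and not part of the kernel we build below:

def naive_vs_fused_traffic(M, N):
    # naive_softmax: reads 5MN + 2M elements and writes 3MN + 2M elements (see the comments above)
    naive = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)
    # fused kernel: reads each of the MN elements once and writes each of them once
    fused = 2 * M * N
    return naive / fused


print(naive_vs_fused_traffic(4096, 1024))  # ~4.0, matching the (8MN + 4M) / 2MN bound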

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each of them, and writes the results back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than or equal to n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
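
Why is -float('inf') the right padding value for the masked-off lanes? It is the identity element of the max reduction, and exp(-inf) = 0, so the padded columns also drop out of the sum. The following minimal check of that property uses plain PyTorch (reusing the torch import from the top of the script; the tensors are only illustrative):

row = torch.tensor([1.0, 2.0, 3.0])                       # n_cols = 3
padded = torch.cat([row, torch.tensor([-float('inf')])])  # "padded" to BLOCK_SIZE = 4
assert padded.max() == row.max()                          # -inf is neutral for the max
assert torch.isclose(torch.exp(padded - padded.max()).sum(),
                     torch.exp(row - row.max()).sum())    # exp(-inf) == 0 removes the padding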

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
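    # More stages let the compiler overlap global-memory loads for upcoming rows with compute on the
    # current one, at the cost of extra on-chip buffering (hence the shared-memory threshold below).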
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
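    # (`warmup` compiles the kernel for these meta-parameters without actually launching it, so we can
    # inspect its resource usage first.)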
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is only half of all
        # registers available, the other half being the accumulation VGPR pool; since the split between the two pools
        # is flexible, in most cases all registers can be used as regular purpose registers, which is why we double
        # the count below.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # Dividing it by WARP_SIZE gives the maximum number of waves that can execute
        # on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
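        # Register-limited occupancy on CUDA: each program runs num_warps * WARP_SIZE threads and each
        # thread needs n_regs registers, so the register file bounds how many programs fit per SM.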
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
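    # Each program then walks over rows row_start, row_start + num_programs, ... (the tl.range loop above),
    # so this fixed grid covers any number of rows.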
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results are identical.

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against torch.softmax (the naive_softmax defined above is not included in the plot).

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
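    # Effective bandwidth: every element is read once and written once (hence the factor of 2), reported in GB/s.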
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[softmax-performance plot: achieved bandwidth (GB/s) vs. number of columns N for Triton and Torch]
softmax-performance:
          N       Triton        Torch
0     256.0   464.688720   692.938580
1     384.0   654.126790   825.132325
2     512.0   818.300029   924.469771
3     640.0   871.229190   957.468235
4     768.0   962.878166  1018.703574
5     896.0  1011.995003  1073.750907
6    1024.0  1056.050178  1108.067636
7    1152.0  1105.463698  1031.631006
8    1280.0  1146.149830  1070.598375
9    1408.0  1159.351170  1105.330973
10   1536.0  1191.687465  1132.787967
11   1664.0  1217.097136  1173.909889
12   1792.0  1240.603501  1195.301412
13   1920.0  1253.470703  1203.031238
14   2048.0  1283.268784  1226.030734
15   2176.0  1242.632716   962.070986
16   2304.0  1248.049885   999.609239
17   2432.0  1274.926347  1041.340876
18   2560.0  1281.169611  1072.370670
19   2688.0  1296.814748  1094.364012
20   2816.0  1298.172169  1121.491896
21   2944.0  1319.227565  1146.803676
22   3072.0  1330.097481  1171.360905
23   3200.0  1332.876867  1171.775419
24   3328.0  1339.956344  1203.119119
25   3456.0  1358.062302  1221.891395
26   3584.0  1354.210716  1246.541922
27   3712.0  1360.723908  1268.156550
28   3840.0  1373.000826  1285.773025
29   3968.0  1372.416746  1300.309059
30   4096.0  1378.063787  1317.890421
31   4224.0  1330.286526  1292.293692
32   4352.0  1334.506708  1318.067018
33   4480.0  1339.301499  1335.662865
34   4608.0  1350.503850  1354.550351
35   4736.0  1350.860060  1365.431206
36   4864.0  1363.951501  1378.884124
37   4992.0  1362.394672  1391.994112
38   5120.0  1359.223254  1409.694996
39   5248.0  1370.180858  1363.435643
40   5376.0  1368.095426  1376.192149
41   5504.0  1377.760742  1391.979957
42   5632.0  1385.059956  1400.521743
43   5760.0  1392.100964  1420.562245
44   5888.0  1391.466418  1425.303077
45   6016.0  1392.184857  1442.840343
46   6144.0  1396.006571  1450.822235
47   6272.0  1402.748404  1413.408445
48   6400.0  1413.048644  1423.461539
49   6528.0  1398.573334  1432.581200
50   6656.0  1411.208438  1438.838713
51   6784.0  1412.941262  1432.763064
52   6912.0  1418.667368  1447.672571
53   7040.0  1413.166605  1465.075031
54   7168.0  1411.101564  1469.341518
55   7296.0  1422.599712  1085.623578
56   7424.0  1423.367495  1100.738953
57   7552.0  1420.856437  1112.232121
58   7680.0  1429.991917  1125.328773
59   7808.0  1420.944846  1137.191547
60   7936.0  1432.304065  1146.355675
61   8064.0  1429.253738  1152.925304
62   8192.0  1436.852184  1156.043964
63   8320.0  1382.700179  1111.722960
64   8448.0  1381.481276  1124.085935
65   8576.0  1385.031773  1122.551912
66   8704.0  1383.165577  1128.864932
67   8832.0  1387.880932  1129.187059
68   8960.0  1393.675899  1136.138317
69   9088.0  1406.009666  1131.415370
70   9216.0  1388.189750  1130.316407
71   9344.0  1397.507824  1420.154356
72   9472.0  1386.354354  1432.468978
73   9600.0  1392.742611  1430.037456
74   9728.0  1400.073400  1439.349554
75   9856.0  1396.557020  1441.031911
76   9984.0  1389.653683  1447.830828
77  10112.0  1400.360057  1455.332909
78  10240.0  1419.363776  1461.132698
79  10368.0  1405.464398  1464.213902
80  10496.0  1411.916134  1460.689565
81  10624.0  1396.051014  1467.058987
82  10752.0  1401.623079  1469.733368
83  10880.0  1398.659544  1482.285795
84  11008.0  1408.951122  1474.462044
85  11136.0  1418.395161  1486.416762
86  11264.0  1414.135126  1483.125812
87  11392.0  1412.345093  1487.860328
88  11520.0  1411.162452  1495.056745
89  11648.0  1419.753509  1498.729929
90  11776.0  1426.601391  1503.525685
91  11904.0  1426.265459  1509.781408
92  12032.0  1415.529415  1512.161779
93  12160.0  1401.842903  1516.788589
94  12288.0  1431.658606  1424.244959
95  12416.0  1432.576769  1394.620911
96  12544.0  1445.248479  1392.938551
97  12672.0  1430.663156  1395.926252
In the above plot, we can see that:
  • naive_softmax is not shown, but the memory-traffic analysis above implies it would reach only about a quarter of this effective bandwidth, since it moves roughly 4x more data per row; torch.jit.script does not perform this kind of fusion automatically, which is why a hand-written fused kernel pays off.

  • Triton is competitive with torch.softmax, and noticeably faster for many column counts, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
