Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.
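
As a quick sanity check on this estimate, the arithmetic below plugs in an example shape (the shape is arbitrary and used only for illustration; counts are in elements, matching the comments in naive_softmax):

M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # naive_softmax: reads + writes
fused_traffic = 2 * M * N                                  # fused kernel: read X once, write Y once
print(naive_traffic / fused_traffic)  # ~4.002, i.e. the ~4x estimate above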

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them, and writes the results back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
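
To see what this means concretely, here is a minimal PyTorch-only sketch of the work assigned to a single program, given a hypothetical grid of num_programs programs (it reuses torch and triton imported above, and is not how the kernel actually executes on the GPU). Padding with -inf is what makes the masking safe: exp(-inf) == 0, so the padded positions change neither the row maximum nor the sum.

def one_program_softmax_sketch(x, y, pid, num_programs):
    # Emulate program `pid`: it handles rows pid, pid + num_programs, pid + 2 * num_programs, ...
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    for row_idx in range(pid, n_rows, num_programs):
        # "Load" the row, padded with -inf up to BLOCK_SIZE elements
        row = torch.full((BLOCK_SIZE,), -float('inf'), dtype=x.dtype, device=x.device)
        row[:n_cols] = x[row_idx]
        row_minus_max = row - row.max()       # the -inf padding never wins the max
        numerator = torch.exp(row_minus_max)  # exp(-inf) == 0, so padding adds nothing to the sum
        y[row_idx] = (numerator / numerator.sum())[:n_cols]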

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor)  in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
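        # Register-limited occupancy: each program runs num_warps warps of WARP_SIZE threads,
        # and each thread needs n_regs registers out of the SM's register file (NUM_REGS).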
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
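
To make the occupancy arithmetic concrete, here is a worked example with made-up numbers (they stand in for the device properties and compiled-kernel statistics queried above, and do not correspond to any particular GPU):

# Hypothetical values, for illustration only
num_sm, num_regs, size_smem, warp_size = 108, 65536, 166912, 32  # device properties
n_regs, kernel_smem, num_warps, n_rows = 32, 16384, 8, 1823      # compiled kernel / input

occupancy = num_regs // (n_regs * warp_size * num_warps)  # 8 programs fit per SM by registers
occupancy = min(occupancy, size_smem // kernel_smem)      # 10 fit by shared memory -> still 8
num_programs = min(num_sm * occupancy, n_rows)            # 864 persistent programs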

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match within floating-point tolerance.

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
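    # Report effective bandwidth assuming the ideal traffic of one read and one write per
    # element (hence the factor of 2), converted from bytes per millisecond to GB/s.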
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[softmax-performance plot: throughput (GB/s) vs. N for Triton, Torch, and Naive Softmax]
softmax-performance:
          N  Triton (GB/s)  Torch (GB/s)  Naive Softmax (GB/s)
0     256.0     470.649762    700.817182            206.202414
1     384.0     665.802183    824.408682            260.148190
2     512.0     801.862448    931.228685            302.507488
3     640.0     915.307256    931.417421            330.687754
4     768.0     977.271387    989.579833            350.998878
5     896.0    1040.031950   1030.076948            355.016508
6    1024.0    1071.696803   1068.521748            354.729148
7    1152.0    1093.441164   1074.746394            348.231019
8    1280.0    1131.573182   1111.253166            346.947063
9    1408.0    1161.282231   1132.801877            340.319450
10   1536.0    1186.796074   1168.022114            333.271551
11   1664.0    1214.549436   1190.689087            329.723156
12   1792.0    1233.596056   1191.868127            325.687480
13   1920.0    1265.317331   1218.887082            324.989199
14   2048.0    1270.269777   1246.482213            324.392242
15   2176.0    1239.035963    960.202709            325.310211
16   2304.0    1252.211755   1000.493773            326.253373
17   2432.0    1274.143996   1035.685042            326.855399
18   2560.0    1282.200577   1067.700715            328.022038
19   2688.0    1290.978981   1095.271461            329.596309
20   2816.0    1315.235709   1122.071462            329.133315
21   2944.0    1317.180629   1145.012259            331.550763
22   3072.0    1319.233956   1167.182948            333.001801
23   3200.0    1342.279506   1169.285843            335.116677
24   3328.0    1350.865585   1200.268281            336.388627
25   3456.0    1355.806049   1225.959260            337.489703
26   3584.0    1361.862122   1247.649271            338.532892
27   3712.0    1368.762294   1268.848621            340.884374
28   3840.0    1374.583808   1282.665277            340.752389
29   3968.0    1372.231833   1301.804441            341.054304
30   4096.0    1388.683049   1314.359800            338.761105
31   4224.0    1330.487821   1274.618820            343.511607
32   4352.0    1347.563827   1295.479070            345.695683
33   4480.0    1346.809681   1319.590236            345.818831
34   4608.0    1358.191917   1333.991521            347.080194
35   4736.0    1361.040033   1343.106037            348.184513
36   4864.0    1363.522118   1359.966632            349.167947
37   4992.0    1369.702544   1370.803051            350.456854
38   5120.0    1374.582795   1385.667875            350.887760
39   5248.0    1374.271683   1349.286562            351.359184
40   5376.0    1381.406730   1369.016430            351.586089
41   5504.0    1379.966897   1383.871244            353.329936
42   5632.0    1393.559823   1392.822484            353.259526
43   5760.0    1395.423256   1402.666349            354.744631
44   5888.0    1388.677987   1413.079302            354.702503
45   6016.0    1396.930463   1425.527471            356.857398
46   6144.0    1407.058343   1425.629150            356.866586
47   6272.0    1410.561443   1402.726720            357.890582
48   6400.0    1412.159158   1410.444825            358.608969
49   6528.0    1417.551815   1417.327986            359.257360
50   6656.0    1414.882871   1426.324873            359.195420
51   6784.0    1415.387784   1437.879807            360.371846
52   6912.0    1423.198362   1439.976732            360.794595
53   7040.0    1420.293252   1453.742787            360.699916
54   7168.0    1420.668580   1460.348105            361.495303
55   7296.0    1429.509711   1088.475858            362.761296
56   7424.0    1432.045253   1096.902001            363.012689
57   7552.0    1425.631090   1111.856195            363.792010
58   7680.0    1429.377949   1125.389247            364.238529
59   7808.0    1433.352665   1135.808385            364.239728
60   7936.0    1431.839078   1142.796993            364.887240
61   8064.0    1435.160582   1145.680746            365.003683
62   8192.0    1430.108525   1153.209998            364.236624
63   8320.0    1378.704137   1118.335948            361.938493
64   8448.0    1387.603514   1123.385591            362.765299
65   8576.0    1388.057191   1129.782728            363.227729
66   8704.0    1380.015420   1134.285404            364.699283
67   8832.0    1395.608068   1133.238153            365.182897
68   8960.0    1382.051082   1139.272133            366.208079
69   9088.0    1392.894054   1137.629419            366.840078
70   9216.0    1401.954908   1141.179532            367.945593
71   9344.0    1387.520847   1420.240647            367.559703
72   9472.0    1398.659764   1432.594786            368.914691
73   9600.0    1401.228323   1432.820464            369.084378
74   9728.0    1397.644151   1441.338598            369.656999
75   9856.0    1400.863546   1442.443755            369.801939
76   9984.0    1391.201476   1451.761345            370.389558
77  10112.0    1405.789264   1454.953735            371.251094
78  10240.0    1410.057077   1461.769519            371.918522
79  10368.0    1416.731529   1464.896550            370.001534
80  10496.0    1406.557332   1467.198439            370.790823
81  10624.0    1409.377877   1463.994107            371.201995
82  10752.0    1395.665765   1471.363991            371.260926
83  10880.0    1393.253976   1478.918165            371.626769
84  11008.0    1422.356301   1475.684654            372.899551
85  11136.0    1416.770068   1482.433896            372.819612
86  11264.0    1410.988554   1487.414068            372.598124
87  11392.0    1421.901707   1489.911729            374.107389
88  11520.0    1413.596349   1498.586935            374.116891
89  11648.0    1420.903985   1499.436604            374.564754
90  11776.0    1432.964055   1500.297313            374.677610
91  11904.0    1433.214730   1508.526163            375.329335
92  12032.0    1413.734740   1508.388290            376.199185
93  12160.0    1413.402203   1513.046680            375.624993
94  12288.0    1429.781400   1419.123522            375.285685
95  12416.0    1436.838607   1394.885639            374.672414
96  12544.0    1443.189303   1394.950661            375.546905
97  12672.0    1436.251966   1389.530552            375.246765
In the plot and table above, we can see that:
  • Triton is roughly 4x faster than naive_softmax. This matches our estimate above and confirms that the naive PyTorch implementation performs no fusion: every intermediate result makes a round trip through DRAM.

  • Triton’s throughput is on par with torch.softmax, and better for many of the tested column counts, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
