Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native PyTorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is wasteful; we’d prefer a custom “fused” kernel that reads X only once and does all the necessary computations on-chip. Doing so would require reading only \(MN\) elements and writing back only \(MN\) elements, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal.
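The traffic estimate above can be checked with a few lines of arithmetic. The sketch below is plain Python; the helper names `naive_traffic` and `fused_traffic` are illustrative and not part of the tutorial's code. It counts elements rather than bytes, so the ratio does not depend on the dtype.

```python
def naive_traffic(M, N):
    """Elements moved by naive_softmax, per the comments in its body."""
    reads = 5 * M * N + 2 * M
    writes = 3 * M * N + 2 * M
    return reads + writes


def fused_traffic(M, N):
    """Elements moved by an ideal fused kernel: read X once, write Y once."""
    return 2 * M * N


M, N = 4096, 1024
speedup = naive_traffic(M, N) / fused_traffic(M, N)
print(f"theoretical speed-up: {speedup:.3f}x")
```

The ratio simplifies to \(4 + 2/N\), so it approaches 4x from above as the rows get longer and is independent of the number of rows.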

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row, and writes the result back to the output Y.
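To make the row-to-program mapping concrete, here is a minimal plain-Python sketch of the same strided loop the kernel uses. The helper `rows_for_program` and the sizes are hypothetical, chosen only for illustration.

```python
def rows_for_program(pid, n_rows, num_programs):
    """Rows visited by program `pid`: start at pid, step by the number of programs."""
    return list(range(pid, n_rows, num_programs))


n_rows, num_programs = 10, 4  # small hypothetical sizes for illustration
assignment = {pid: rows_for_program(pid, n_rows, num_programs) for pid in range(num_programs)}
for pid, rows in assignment.items():
    print(pid, rows)
```

Every row is handled by exactly one program, and the work stays balanced to within one row even when `n_rows` is not a multiple of the number of programs.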

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
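The padding-and-masking idea in the kernel can be mimicked on the CPU. The following is a plain-Python sketch, not Triton code: `next_power_of_2` is a stand-in written here for `triton.next_power_of_2`, and `padded_softmax` is a hypothetical helper. Padded lanes are filled with `-inf`, just as the masked load does with `other=-float('inf')`, so they contribute `exp(-inf) = 0` to the sum and never win the max.

```python
import math


def next_power_of_2(n):
    """Smallest power of two >= n (plain-Python stand-in for triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def padded_softmax(row):
    """Softmax of `row` computed on a block padded with -inf, like the kernel's masked load."""
    block_size = next_power_of_2(len(row))
    block = list(row) + [-math.inf] * (block_size - len(row))
    row_max = max(block)
    numerator = [math.exp(v - row_max) for v in block]  # exp(-inf) == 0.0 for padded lanes
    denominator = sum(numerator)
    out = [v / denominator for v in numerator]
    return out[:len(row)]  # the masked store drops the padded lanes


row = [1.0, 2.0, 3.0, 4.0, 5.0]
result = padded_softmax(row)
print(next_power_of_2(len(row)), sum(result))
```

Because the padding is neutral for both the max and the sum, the result for the real columns is the same as a softmax over the unpadded row.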

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # Dividing this number by WARP_SIZE gives the maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
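The occupancy arithmetic in the non-HIP branch of `softmax` is easier to follow with concrete numbers. The sketch below repeats that arithmetic in plain Python; the function name `estimate_num_programs` and every device and kernel figure passed to it are assumptions made up for illustration, not values queried from real hardware.

```python
def estimate_num_programs(num_sm, num_regs, size_smem_total, warp_size,
                          n_regs, size_smem_kernel, num_warps, n_rows):
    """Mirror of the non-HIP branch in softmax(): programs that can be resident at once."""
    occupancy = num_regs // (n_regs * warp_size * num_warps)  # register limit per SM
    occupancy = min(occupancy, size_smem_total // size_smem_kernel)  # shared-memory limit per SM
    return min(num_sm * occupancy, n_rows)  # never launch more programs than rows


# Hypothetical device and kernel figures, chosen only to make the arithmetic easy to follow.
num_programs = estimate_num_programs(
    num_sm=108, num_regs=65536, size_smem_total=166912, warp_size=32,
    n_regs=32, size_smem_kernel=16384, num_warps=8, n_rows=1823)
print(num_programs)
```

With these figures the register file allows 8 programs per SM and shared memory allows 10, so registers are the binding limit and 108 × 8 = 864 persistent programs are launched.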

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match to within the default tolerances of torch.allclose.
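For reference, `torch.allclose` does not test bitwise equality: it checks `|a - b| <= atol + rtol * |b|` elementwise, with defaults `rtol=1e-5` and `atol=1e-8`. Here is a minimal plain-Python sketch of that criterion; the `allclose` function and the sample values are illustrative only.

```python
def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Elementwise |a - b| <= atol + rtol * |b|, the criterion torch.allclose documents."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


reference = [0.25, 0.75]
computed = [0.25 + 1e-7, 0.75 - 1e-7]  # small made-up rounding differences

print(computed == reference)          # exact comparison
print(allclose(computed, reference))  # tolerance-based comparison
```

A tolerance-based check is the appropriate one here, since the kernel's exponentiation and reduction order need not match PyTorch's bit for bit.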

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Figure: softmax-performance plot, GB/s versus N for Triton, Torch, and Naive Softmax]
softmax-performance:
          N  Triton (GB/s)  Torch (GB/s)  Naive Softmax (GB/s)
0     256.0     500.056313    691.626019            203.192830
1     384.0     696.062625    813.845754            260.071501
2     512.0     835.022220    918.642249            300.099694
3     640.0     837.570612    914.495068            331.498427
4     768.0     910.458857    990.967742            351.839647
5     896.0     967.244058   1029.561299            355.314277
6    1024.0    1030.878009   1081.342046            353.232439
7    1152.0    1027.968971   1073.265719            347.298383
8    1280.0    1071.140288   1101.689922            347.551781
9    1408.0    1113.470090   1134.705382            342.451459
10   1536.0    1145.985502   1154.428188            332.555958
11   1664.0    1188.469599   1181.463388            328.990865
12   1792.0    1203.988488   1200.802825            325.804163
13   1920.0    1237.792045   1219.498761            326.324876
14   2048.0    1249.280876   1244.770839            325.679763
15   2176.0    1180.409644    960.102574            325.528351
16   2304.0    1197.911336   1004.126619            326.041386
17   2432.0    1220.757470   1038.762584            327.458095
18   2560.0    1236.683111   1067.281502            327.915274
19   2688.0    1260.046411   1096.970480            329.348194
20   2816.0    1278.139925   1122.202008            328.969537
21   2944.0    1293.466512   1148.473081            331.520087
22   3072.0    1313.583386   1174.572486            333.644160
23   3200.0    1326.784009   1168.825188            335.086250
24   3328.0    1326.725616   1198.894453            336.854932
25   3456.0    1338.935214   1223.314109            337.391154
26   3584.0    1341.160594   1247.464691            338.296674
27   3712.0    1336.926377   1263.590072            340.288237
28   3840.0    1361.291851   1286.526136            340.584794
29   3968.0    1362.983610   1295.021825            341.080402
30   4096.0    1368.590779   1315.977071            338.893680
31   4224.0    1337.126543   1274.053960            342.404101
32   4352.0    1345.669905   1299.844559            344.947548
33   4480.0    1358.561032   1318.110175            345.802896
34   4608.0    1366.813941   1335.028025            347.393896
35   4736.0    1361.060790   1344.125600            348.439905
36   4864.0    1379.597313   1356.449819            349.398080
37   4992.0    1369.783555   1369.853953            350.438165
38   5120.0    1387.651729   1388.767977            351.257295
39   5248.0    1382.438090   1357.973029            351.840500
40   5376.0    1381.536173   1362.970925            351.367761
41   5504.0    1389.537390   1382.731570            353.535239
42   5632.0    1400.149133   1398.331249            353.445511
43   5760.0    1403.211451   1405.312739            355.014219
44   5888.0    1397.281267   1415.510248            354.841607
45   6016.0    1408.297269   1417.921250            356.919696
46   6144.0    1416.889903   1431.642685            357.348429
47   6272.0    1412.015045   1403.142553            357.516387
48   6400.0    1416.556478   1409.997479            358.622822
49   6528.0    1418.322293   1424.740618            359.079778
50   6656.0    1418.749015   1436.639885            359.062388
51   6784.0    1425.013066   1432.909455            360.045417
52   6912.0    1430.740630   1452.567868            360.500573
53   7040.0    1424.275770   1454.312962            361.066645
54   7168.0    1425.342721   1469.095332            361.673141
55   7296.0    1431.835618   1088.316987            362.983241
56   7424.0    1430.429088   1100.058589            362.830912
57   7552.0    1433.148163   1111.496335            363.696198
58   7680.0    1434.464478   1123.085351            363.381741
59   7808.0    1433.595854   1135.670435            364.662283
60   7936.0    1435.272831   1144.102307            364.645772
61   8064.0    1440.980920   1150.457259            365.252940
62   8192.0    1431.998891   1152.154218            363.737060
63   8320.0    1391.066545   1114.616869            361.288413
64   8448.0    1381.762661   1126.645098            362.376240
65   8576.0    1389.555569   1127.866061            363.174319
66   8704.0    1385.894967   1135.795364            364.422793
67   8832.0    1390.679493   1132.780439            364.831566
68   8960.0    1386.789222   1140.230247            365.972137
69   9088.0    1401.070158   1138.054658            366.571799
70   9216.0    1404.422469   1141.420110            367.618343
71   9344.0    1394.129021   1423.721125            367.206633
72   9472.0    1405.002507   1431.404564            368.566525
73   9600.0    1397.438979   1427.451848            368.975432
74   9728.0    1397.824134   1439.531795            369.550616
75   9856.0    1398.684989   1440.355758            369.952433
76   9984.0    1391.374582   1448.018068            370.367195
77  10112.0    1405.495863   1454.754718            371.513056
78  10240.0    1405.140854   1465.643055            371.653758
79  10368.0    1412.885493   1462.872500            369.770961
80  10496.0    1403.994888   1467.297766            370.448541
81  10624.0    1404.777868   1467.325354            370.850471
82  10752.0    1399.186122   1471.662808            371.400726
83  10880.0    1397.558453   1478.554339            371.979058
84  11008.0    1420.242113   1479.438072            371.935129
85  11136.0    1424.050606   1485.278001            373.290036
86  11264.0    1414.994760   1488.049081            373.378928
87  11392.0    1419.858756   1487.304962            374.394446
88  11520.0    1407.667639   1493.888328            374.226985
89  11648.0    1415.753170   1501.613686            374.290402
90  11776.0    1431.472624   1502.882488            375.250591
91  11904.0    1431.230552   1510.543154            376.335801
92  12032.0    1425.182627   1508.575051            376.327660
93  12160.0    1420.287355   1516.435817            376.533679
94  12288.0    1429.188170   1421.846300            376.538108
95  12416.0    1432.113344   1393.922980            374.425254
96  12544.0    1440.117980   1396.146810            375.209031
97  12672.0    1437.111999   1396.447843            375.405073
In the above plot, we can see that:
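  • Triton is roughly 4x faster than naive_softmax at large N (about 1,400 GB/s versus about 370 GB/s). This matches the theoretical estimate and confirms that the unfused PyTorch version pays for its extra round trips to DRAM.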

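  • Triton is competitive with torch.softmax rather than uniformly faster: it leads by more than 20% for N from 7296 to 9216 and by smaller margins for N from 2176 to 4864, while torch.softmax is ahead for N up to 1536 and for N from 9344 to 12160. The Triton kernel is also easy to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.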
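The GB/s figures in the table come from the `gbps` lambda in the benchmark, which charges every provider for one read and one write of the tensor. A plain-Python sketch of that conversion follows; the function name `effective_gbps` is illustrative, and the 0.1 ms timing is a made-up round number rather than a measurement.

```python
def effective_gbps(M, N, element_size, ms):
    """Same formula as the benchmark: one read plus one write of an M x N tensor."""
    total_bytes = 2 * M * N * element_size
    return total_bytes * 1e-9 / (ms * 1e-3)


# 4096 x 4096 float32 tensor; the timing below is hypothetical.
gbps = effective_gbps(M=4096, N=4096, element_size=4, ms=0.1)
print(f"{gbps:.1f} GB/s")
```

Since all three providers are scored with the same `2 * M * N` byte count, the lower figure for naive_softmax reflects the extra DRAM traffic it generates for the same useful work.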

Total running time of the script: (0 minutes 40.100 seconds)
