Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native PyTorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret
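Shifting by the row maximum matters in practice. The pure-Python sketch below uses only the standard library, and its helper names `softmax_unstable` and `softmax_stable` are hypothetical. Exponentiating large inputs directly overflows. Subtracting the maximum first gives the same probabilities without overflow.

```python
import math


def softmax_unstable(row):
    # exp() of a large input overflows: math.exp raises OverflowError above ~709.78
    exps = [math.exp(v) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_stable(row):
    # Subtract the row maximum first, so the largest exponent becomes exp(0) = 1
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


small = [1.0, 2.0, 3.0]
# Shift invariance: both versions agree when nothing overflows
assert all(math.isclose(a, b) for a, b in zip(softmax_unstable(small), softmax_stable(small)))

large = [1000.0, 1001.0, 1002.0]
try:
    softmax_unstable(large)
except OverflowError:
    print("unstable version overflowed")
print(softmax_stable(large))  # same probabilities as softmax_stable(small)
```

After the shift, `large` and `small` reduce to the same values `[-2.0, -1.0, 0.0]`, which is why their outputs match exactly.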

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful. We’d prefer a custom “fused” kernel that reads X only once and does all the necessary computations on-chip. Such a kernel would read only \(MN\) elements and write back only \(MN\) elements, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). TorchScript (torch.jit.script) aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal.
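You can check this arithmetic with a few lines of plain Python. The sketch below uses hypothetical helpers and only the standard library. Like the comments in naive_softmax, it counts elements rather than bytes. Multiplying both counts by the element size gives bytes and leaves the ratio unchanged. That ratio simplifies to \(4 + 2/N\).

```python
def naive_traffic(M, N):
    # (reads, writes) in elements, summed from the comments in naive_softmax
    reads = 5 * M * N + 2 * M
    writes = 3 * M * N + 2 * M
    return reads, writes


def fused_traffic(M, N):
    # A fused kernel reads X once and writes Y once
    return M * N, M * N


def ideal_speedup(M, N):
    # (8MN + 4M) / 2MN, which equals 4 + 2/N
    return sum(naive_traffic(M, N)) / sum(fused_traffic(M, N))


for n in (256, 1024, 4096):
    print(n, ideal_speedup(4096, n))
```

The \(2M\) terms come from the row-wise max and sum vectors. They become negligible as N grows, so the ideal speed-up approaches 4 from above.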

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row and writes the result back to the output Y.
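The following pure-Python sketch shows which rows each program visits. Its helper `rows_for_program` is hypothetical and not part of Triton. The program's starting row is `tl.program_id(0)`, and the stride between its rows is `tl.num_programs(0)`.

```python
def rows_for_program(pid, num_programs, n_rows):
    # Mirrors tl.range(row_start, n_rows, row_step) with row_start = pid, row_step = num_programs
    return list(range(pid, n_rows, num_programs))


n_rows, num_programs = 10, 4
assignment = {pid: rows_for_program(pid, num_programs, n_rows) for pid in range(num_programs)}
print(assignment)  # {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}

# Every row is handled by exactly one program
all_rows = sorted(r for rows in assignment.values() for r in rows)
assert all_rows == list(range(n_rows))
```

This striding lets a fixed number of "persistent" programs cover any number of rows. When `n_rows` is not a multiple of the program count, some programs simply get one row fewer.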

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols, so we can fit each
        # row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
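The padding trick works because of the choice `other=-float('inf')`. A padded lane can never win the max, and it contributes `exp(-inf) = 0` to the denominator. Below is a pure-Python emulation of one row, a sketch with hypothetical helper names. `next_power_of_2` mirrors the semantics of `triton.next_power_of_2` for n ≥ 1.

```python
import math


def next_power_of_2(n):
    # Smallest power of two >= n (valid for n >= 1)
    return 1 << (n - 1).bit_length()


def padded_row_softmax(row):
    n_cols = len(row)
    block_size = next_power_of_2(n_cols)
    # Emulate tl.load(..., mask=col_offsets < n_cols, other=-inf)
    block = [row[c] if c < n_cols else -math.inf for c in range(block_size)]
    # Padding lanes hold -inf, so they never become the maximum
    m = max(block)
    row_minus_max = [v - m for v in block]
    # exp(-inf) == 0.0, so padding lanes add nothing to the denominator
    numerator = [math.exp(v) for v in row_minus_max]
    denominator = sum(numerator)
    out = [v / denominator for v in numerator]
    # Emulate the masked tl.store: only the first n_cols lanes are written back
    return out[:n_cols]


print(next_power_of_2(781))               # 1024, the BLOCK_SIZE for the unit-test shape below
print(padded_row_softmax([0.5, -1.0, 2.0]))  # BLOCK_SIZE is 4 here
```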

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
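The occupancy heuristic in `softmax` is plain integer arithmetic. The pure-Python restatement below makes it easier to follow. The function name is hypothetical, and the device numbers in the example are made up for illustration, not measured. The guard against a kernel that uses no shared memory is an addition to the original code.

```python
def estimate_num_programs(n_rows, n_regs, size_smem, *, num_sm, num_regs, size_smem_max,
                          warp_size, num_warps, hip=False, cdna=False, max_threads_per_sm=0):
    """Hypothetical restatement of the occupancy heuristic in `softmax`.

    n_regs / size_smem describe the compiled kernel; the keyword arguments describe the device.
    """
    if hip:
        # On CDNA, regular and accumulation VGPRs are counted together, hence the factor of 2
        num_gprs = num_regs * 2 if cdna else num_regs
        max_num_waves = max_threads_per_sm // warp_size
        occupancy = min(num_gprs // warp_size // n_regs, max_num_waves) // num_warps
    else:
        # Registers one program needs: per-thread registers * threads per warp * warps
        occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Shared memory can also cap the number of resident programs per SM
    # (the size_smem > 0 guard is an addition to the original heuristic)
    if size_smem > 0:
        occupancy = min(occupancy, size_smem_max // size_smem)
    # Enough persistent programs to fill every SM, but never more programs than rows
    return min(num_sm * occupancy, n_rows)


# Hypothetical NVIDIA-style device and kernel numbers, chosen only for illustration
print(estimate_num_programs(4096, n_regs=32, size_smem=2048, num_sm=132, num_regs=65536,
                            size_smem_max=232448, warp_size=32, num_warps=8))  # 1056
```

In this example, registers limit each SM to 8 programs. The resulting 1056 programs then share the 4096 rows, 3 or 4 rows each.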

Unit Test

We test our kernel on a matrix with an irregular number of rows and columns, which lets us verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match within torch.allclose’s default tolerances.
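torch.allclose(input, other, rtol=1e-05, atol=1e-08) considers two tensors close when \(|input - other| \le atol + rtol \cdot |other|\) holds elementwise. It does not require bit-identical outputs. The sketch below restates that criterion in pure Python for flat lists of finite floats. The helper name is hypothetical, and it skips torch's broadcasting and special handling of NaN and infinity.

```python
def allclose_like(a, b, rtol=1e-05, atol=1e-08):
    # Elementwise |a - b| <= atol + rtol * |b|; note the tolerance scales with |b| only
    return len(a) == len(b) and all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


print(allclose_like([0.25, 0.75], [0.25 + 1e-7, 0.75]))  # True: difference within tolerance
print(allclose_like([0.25, 0.75], [0.2501, 0.75]))       # False: difference of 1e-4 is too large
```

A tolerance-based check is appropriate here because the Triton kernel uses a fast, approximate exponential and may sum in a different order than torch.softmax.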

Benchmark

Here we benchmark our operation as a function of the number of columns in the input matrix, with the number of rows fixed at 4096. We then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
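The benchmark below reports throughput through the `gbps` lambda. That lambda charges every provider for the same \(2 \times M \times N \times\) element_size bytes, which is the minimum traffic of a fused kernel. The pure-Python restatement below makes the units explicit and shows how to turn a reported GB/s value back into a runtime. The helper names are hypothetical, and the 0.025 ms timing is made up for illustration.

```python
def effective_gbps(M, N, element_size, ms):
    # Bytes an ideal fused kernel must move: read X once and write Y once
    bytes_moved = 2 * M * N * element_size
    # Same expression as the `gbps` lambda in `benchmark`: gigabytes divided by seconds
    return bytes_moved * 1e-9 / (ms * 1e-3)


def implied_ms(M, N, element_size, gbps):
    # Invert the formula to recover the runtime (in ms) behind a reported throughput
    return 2 * M * N * element_size * 1e-9 / gbps * 1e3


print(effective_gbps(4096, 1024, 4, 0.025))     # float32 input, hypothetical 0.025 ms run
print(implied_ms(4096, 1024, 4, 1083.969623))  # Triton entry for N = 1024 in the results below
```

The byte count is fixed, so the figure reported for naive_softmax is an *effective* bandwidth. Per the estimate in the Motivations section, its actual DRAM traffic is roughly four times larger.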

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `line_arg`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Plot: softmax-performance, throughput (GB/s) vs. N for Triton, Torch and Naive Softmax]
softmax-performance:
          N  Triton (GB/s)  Torch (GB/s)  Naive Softmax (GB/s)
0     256.0     479.392738    688.644471            208.291026
1     384.0     666.701587    819.647163            265.336130
2     512.0     814.381594    911.821022            300.529921
3     640.0     918.974351    915.359612            328.833999
4     768.0     988.334036    974.393732            349.393180
5     896.0    1052.354057   1040.946241            356.511876
6    1024.0    1083.969623   1079.789974            355.806450
7    1152.0    1090.015481   1068.152457            349.438616
8    1280.0    1130.875041   1114.771225            349.595371
9    1408.0    1172.829092   1130.575448            341.536844
10   1536.0    1189.198121   1158.037248            332.885244
11   1664.0    1216.394887   1192.698726            330.042768
12   1792.0    1238.646653   1192.269470            325.922148
13   1920.0    1253.133199   1218.957510            324.107610
14   2048.0    1267.869246   1242.002648            323.808636
15   2176.0    1229.521425    962.382388            325.760834
16   2304.0    1258.626822   1004.627491            326.386060
17   2432.0    1268.672580   1034.881447            327.632528
18   2560.0    1283.067203   1068.026464            328.269200
19   2688.0    1295.403693   1097.625663            329.638462
20   2816.0    1312.918348   1126.770677            329.714936
21   2944.0    1321.249303   1147.475445            331.509081
22   3072.0    1321.962839   1169.654304            334.329260
23   3200.0    1337.790602   1175.557858            334.799862
24   3328.0    1344.825668   1198.484977            336.495734
25   3456.0    1347.881116   1225.761493            336.590062
26   3584.0    1364.855453   1248.444230            338.768782
27   3712.0    1362.249774   1263.894840            340.475992
28   3840.0    1368.319726   1285.649284            340.682410
29   3968.0    1374.910593   1297.324039            341.353283
30   4096.0    1387.623536   1315.103714            339.172059
31   4224.0    1332.575837   1278.375278            343.225698
32   4352.0    1346.440717   1300.796895            345.284246
33   4480.0    1349.759194   1316.351223            346.088020
34   4608.0    1358.267471   1336.244469            346.559639
35   4736.0    1361.059207   1345.500888            348.218048
36   4864.0    1366.759780   1359.102100            349.544293
37   4992.0    1367.526915   1372.345818            350.040602
38   5120.0    1376.970152   1384.510671            351.151968
39   5248.0    1374.523076   1352.943718            351.769267
40   5376.0    1382.906553   1369.164980            352.029528
41   5504.0    1378.327418   1380.601194            354.040175
42   5632.0    1394.310535   1394.658896            353.027318
43   5760.0    1393.677258   1400.080528            355.088660
44   5888.0    1396.089025   1420.363069            355.002988
45   6016.0    1400.837311   1420.025852            356.277630
46   6144.0    1410.312431   1422.664145            356.843455
47   6272.0    1408.401746   1400.618076            358.108065
48   6400.0    1412.474849   1404.704408            358.469040
49   6528.0    1414.434732   1423.880117            359.374819
50   6656.0    1413.966761   1424.257280            359.411236
51   6784.0    1415.473668   1438.572041            360.073442
52   6912.0    1419.738428   1439.471078            360.812988
53   7040.0    1420.920065   1460.479008            361.016527
54   7168.0    1424.308423   1456.577167            361.664017
55   7296.0    1422.336124   1081.533201            362.191231
56   7424.0    1431.017576   1097.218633            362.608483
57   7552.0    1429.993367   1111.797656            363.267944
58   7680.0    1433.983217   1121.882492            363.773756
59   7808.0    1427.894284   1131.042270            364.521322
60   7936.0    1436.483822   1144.706250            364.819339
61   8064.0    1434.042612   1147.319259            365.144132
62   8192.0    1431.391406   1148.758413            363.552395
63   8320.0    1380.422222   1116.801198            361.582812
64   8448.0    1384.886388   1125.678422            362.446951
65   8576.0    1388.245213   1129.163929            363.392506
66   8704.0    1380.456433   1132.834062            364.396059
67   8832.0    1393.213229   1131.710511            364.782700
68   8960.0    1386.567015   1139.804692            365.718730
69   9088.0    1399.766941   1137.059306            366.337985
70   9216.0    1405.970683   1146.345696            367.356037
71   9344.0    1392.151314   1419.416623            367.318290
72   9472.0    1396.291894   1431.596048            368.397985
73   9600.0    1404.701711   1432.142547            368.738328
74   9728.0    1396.771624   1439.911869            369.807814
75   9856.0    1400.908786   1442.263883            369.686348
76   9984.0    1394.004239   1452.396622            370.318004
77  10112.0    1403.358341   1452.022288            371.468629
78  10240.0    1409.026240   1463.730080            371.576925
79  10368.0    1413.428249   1460.566847            370.121370
80  10496.0    1409.937957   1468.626491            370.559601
81  10624.0    1408.787923   1466.562983            371.144105
82  10752.0    1392.216982   1474.168759            371.594674
83  10880.0    1395.957592   1481.201501            372.166852
84  11008.0    1420.387097   1476.240577            372.567191
85  11136.0    1413.925945   1482.428129            373.365592
86  11264.0    1413.352153   1484.408905            373.258931
87  11392.0    1421.908114   1491.101463            374.553622
88  11520.0    1410.223970   1496.615377            374.010968
89  11648.0    1422.333078   1498.795433            375.070260
90  11776.0    1430.976099   1501.197578            374.921690
91  11904.0    1432.826248   1507.580245            375.178277
92  12032.0    1415.458450   1510.852878            375.858488
93  12160.0    1415.097196   1516.700226            376.089869
94  12288.0    1428.788911   1417.846148            375.896896
95  12416.0    1440.006274   1392.656451            374.681248
96  12544.0    1447.985834   1394.750678            375.274802
97  12672.0    1432.899657   1392.639999            375.141306
In the above plot, we can see that:
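  • For N of 2048 and above, Triton is about 4x faster than naive_softmax (roughly 3.7x to 4.1x in this run). This is close to the ~4x speed-up predicted by the memory-traffic estimate. The result is consistent with naive_softmax not being fused: each of its PyTorch operations makes its own round trip to DRAM.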

  • Triton is noticeably faster than torch.softmax – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
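To quantify these comparisons, divide one throughput column by another. Both columns are computed from the same byte count, so their ratio equals the ratio of runtimes. The sketch below uses a few (N, Triton, Naive Softmax) entries copied from the table above. These values come from this particular run and will differ on other hardware.

```python
# (N, Triton GB/s, Naive Softmax GB/s), copied from the results table above
rows = [
    (2048, 1267.869246, 323.808636),
    (4096, 1387.623536, 339.172059),
    (8192, 1431.391406, 363.552395),
    (12288, 1428.788911, 375.896896),
]

for n, triton_gbps, naive_gbps in rows:
    # Same byte count in both columns, so the throughput ratio is the speed-up
    speedup = triton_gbps / naive_gbps
    ideal = 4 + 2 / n  # (8MN + 4M) / 2MN from the memory-traffic estimate
    print(f"N={n}: measured {speedup:.2f}x, ideal {ideal:.3f}x")
```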

Total running time of the script: (0 minutes 35.022 seconds)
