Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in R^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.
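
As a quick sanity check on this estimate, the short calculation below plugs concrete (purely illustrative) values of M and N into the traffic counts above:

# Back-of-the-envelope memory-traffic estimate; M and N are arbitrary example sizes.
M, N = 4096, 1024
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes, in elements
fused_traffic = 2 * M * N                                  # read X once, write Y once
print(naive_traffic / fused_traffic)  # ~4.0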

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each of them, and writes the results back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
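
To see why padding with -inf is safe here, the following small PyTorch-only sketch (not part of the kernel, just an illustration) mimics the effect of the masked load on one short row: -inf never wins the row maximum, and exp(-inf) is 0, so the padded lanes contribute nothing to the sum.

row = torch.randn(5)
BLOCK = 8  # next power of two >= 5, as the kernel would use
padded = torch.full((BLOCK,), float('-inf'))
padded[:5] = row
out = torch.exp(padded - padded.max())
out = out / out.sum()
assert torch.allclose(out[:5], torch.softmax(row, dim=0))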

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, the split is not fixed: in most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
        # Dividing it by WARP_SIZE gives the maximum number of waves that can execute
        # on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
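        # Register-limited occupancy on CUDA: each program runs num_warps warps of WARP_SIZE
        # threads, and every thread needs n_regs registers out of the NUM_REGS available per SM.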
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
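    # Shared memory further caps how many programs can be resident on a multiprocessor at once.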
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
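# With 781 columns, BLOCK_SIZE = triton.next_power_of_2(781) = 1024, so 243 lanes per row are masked out.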
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match (up to floating-point tolerance).

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
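    # Report throughput against the ideal traffic of reading and writing the MN-element tensor
    # once each (hence the factor of 2), converted from bytes per millisecond to GB/s.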
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
softmax-performance:
          N  Triton (GB/s)  Torch (GB/s)  Naive Softmax (GB/s)
0     256.0     479.781182    688.625207            208.987586
1     384.0     657.673842    815.728814            263.644206
2     512.0     805.414257    927.425586            299.593141
3     640.0     911.149048    918.399745            329.779601
4     768.0     988.283253    989.660727            348.472431
5     896.0    1053.691792   1037.975260            353.985226
6    1024.0    1080.694808   1067.017140            352.434521
7    1152.0    1098.452974   1075.291077            348.620721
8    1280.0    1137.718677   1099.305531            348.750081
9    1408.0    1169.334802   1137.996755            339.954948
10   1536.0    1185.699169   1157.250752            333.555329
11   1664.0    1212.157901   1193.285535            330.303689
12   1792.0    1228.453606   1193.109986            326.204874
13   1920.0    1252.607669   1225.363032            324.517779
14   2048.0    1276.729177   1243.581848            324.087924
15   2176.0    1236.741339    959.130061            325.771214
16   2304.0    1252.735515   1004.425475            326.114892
17   2432.0    1265.407257   1034.088036            325.984479
18   2560.0    1281.347128   1071.737649            328.055721
19   2688.0    1295.767387   1100.693101            329.109278
20   2816.0    1314.557607   1126.034001            329.184528
21   2944.0    1317.774173   1143.547856            331.630184
22   3072.0    1325.597177   1170.120676            333.249958
23   3200.0    1341.818311   1174.498366            334.323162
24   3328.0    1344.092006   1196.551544            336.119174
25   3456.0    1349.377115   1223.790312            337.011476
26   3584.0    1358.690087   1239.514034            337.985459
27   3712.0    1369.597754   1261.849190            340.105162
28   3840.0    1372.722466   1282.793638            340.247436
29   3968.0    1377.408638   1296.502510            340.819870
30   4096.0    1390.341948   1315.666402            338.885231
31   4224.0    1326.133645   1279.086809            343.011886
32   4352.0    1340.593447   1295.210272            345.225582
33   4480.0    1348.797177   1318.509138            346.260399
34   4608.0    1356.320739   1336.182709            346.586377
35   4736.0    1361.628777   1343.964607            347.939725
36   4864.0    1365.299096   1356.739776            348.838700
37   4992.0    1370.849053   1373.882160            349.871840
38   5120.0    1372.581946   1384.780686            350.334995
39   5248.0    1376.758297   1353.924955            351.914773
40   5376.0    1380.467876   1368.374138            351.418178
41   5504.0    1383.996045   1379.726486            353.600149
42   5632.0    1393.321318   1393.712962            353.208412
43   5760.0    1395.118355   1405.432369            354.953758
44   5888.0    1391.266428   1417.560204            354.915940
45   6016.0    1402.692226   1419.285124            356.508356
46   6144.0    1409.565803   1427.441797            357.176855
47   6272.0    1408.761964   1397.427764            357.493315
48   6400.0    1410.296723   1411.735196            357.995939
49   6528.0    1412.017019   1412.889957            358.872617
50   6656.0    1413.653525   1431.494328            359.323961
51   6784.0    1418.629389   1433.680566            360.233845
52   6912.0    1421.956900   1440.619506            360.633741
53   7040.0    1420.684692   1449.915226            360.798001
54   7168.0    1422.670018   1455.174501            362.035869
55   7296.0    1423.580122   1082.049192            362.296326
56   7424.0    1427.298193   1096.763620            362.772475
57   7552.0    1426.841788   1108.034556            363.213348
58   7680.0    1430.103091   1122.907699            364.016380
59   7808.0    1431.374146   1128.539145            364.298735
60   7936.0    1432.964422   1139.890770            364.209214
61   8064.0    1434.165083   1147.524841            364.610096
62   8192.0    1433.908065   1148.726968            363.842927
63   8320.0    1380.826916   1115.978689            361.770409
64   8448.0    1381.316460   1123.545476            362.889457
65   8576.0    1389.308485   1123.863512            363.468264
66   8704.0    1381.412793   1132.992060            364.302519
67   8832.0    1396.614561   1131.225215            364.853782
68   8960.0    1388.630961   1138.040113            365.598816
69   9088.0    1396.935100   1135.232854            366.608855
70   9216.0    1405.272267   1141.438830            367.532000
71   9344.0    1387.280022   1421.428375            367.681536
72   9472.0    1396.486184   1427.347615            368.815180
73   9600.0    1400.616524   1431.691084            368.787515
74   9728.0    1398.398902   1443.438327            369.563910
75   9856.0    1399.888402   1440.888556            370.435724
76   9984.0    1393.050699   1451.403084            370.835312
77  10112.0    1402.376974   1456.504987            371.891105
78  10240.0    1409.858981   1463.827179            371.555127
79  10368.0    1415.596316   1461.388737            370.085854
80  10496.0    1413.322769   1468.165544            370.497401
81  10624.0    1408.137115   1463.134142            370.774913
82  10752.0    1397.303252   1472.617419            371.418349
83  10880.0    1394.759547   1481.661099            371.881936
84  11008.0    1422.446060   1479.103316            372.589329
85  11136.0    1419.561328   1486.240756            373.401159
86  11264.0    1410.283620   1485.205713            373.147894
87  11392.0    1422.751370   1490.516449            374.385608
88  11520.0    1410.305717   1499.644971            374.333342
89  11648.0    1422.099167   1499.055070            375.003669
90  11776.0    1438.241289   1501.677608            375.638014
91  11904.0    1432.181417   1508.180933            375.156017
92  12032.0    1417.078289   1508.179311            375.920386
93  12160.0    1413.788364   1516.435822            376.404866
94  12288.0    1426.985668   1421.984003            375.822737
95  12416.0    1436.504123   1400.556055            374.880106
96  12544.0    1444.269509   1394.610410            375.441527
97  12672.0    1437.226001   1394.385728            375.158875
In the plot and printed data above, we can see that:

  • Triton is roughly 4x faster than the naive PyTorch implementation (naive_softmax), matching our theoretical estimate and confirming that the extra DRAM round-trips between the naive implementation’s intermediate steps are the bottleneck.

  • Triton is competitive with torch.softmax, and faster for many of the shapes tested – in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
