Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native PyTorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; write 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we'd prefer a custom “fused” kernel that reads X only once and does all the necessary computations on-chip. Such a kernel would only need to read \(MN\) elements and write back \(MN\) elements, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / (2MN)\)). The torch.jit.script decorator aims to perform this kind of “kernel fusion” automatically, but it is still far from ideal.
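The traffic estimate can be checked with a few lines of plain Python. This is a minimal sketch: the helper names naive_traffic, fused_traffic and expected_speedup are hypothetical, and the counts are copied from the per-line comments in naive_softmax.

```python
def naive_traffic(M: int, N: int) -> tuple[int, int]:
    """Element reads and writes of naive_softmax, summed from its per-line comments."""
    reads = 5 * M * N + 2 * M
    writes = 3 * M * N + 2 * M
    return reads, writes


def fused_traffic(M: int, N: int) -> tuple[int, int]:
    """An ideal fused kernel reads each input element once and writes each output once."""
    return M * N, M * N


def expected_speedup(M: int, N: int) -> float:
    """Ratio of total elements moved: (8MN + 4M) / (2MN) = 4 + 2/N."""
    naive_reads, naive_writes = naive_traffic(M, N)
    fused_reads, fused_writes = fused_traffic(M, N)
    return (naive_reads + naive_writes) / (fused_reads + fused_writes)


print(expected_speedup(4096, 1024))  # 4.001953125
```

The 2M extra elements (the row maxima and row sums) only matter for very narrow rows, which is why the ratio approaches 4 as N grows.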

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each row, and writes the result back to the output Y.
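The row assignment is a grid-stride loop: program pid starts at row pid and advances by the total number of programs, which is what tl.range(row_start, n_rows, row_step) does with row_start = tl.program_id(0) and row_step = tl.num_programs(0). A plain-Python sketch of which rows each program visits (rows_for_program is a hypothetical helper, not part of Triton):

```python
def rows_for_program(pid: int, num_programs: int, n_rows: int) -> list[int]:
    """Rows handled by program `pid` in a grid-stride loop over `n_rows` rows."""
    return list(range(pid, n_rows, num_programs))


n_rows, num_programs = 10, 4
for pid in range(num_programs):
    print(pid, rows_for_program(pid, num_programs, n_rows))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6]
# 3 [3, 7]
```

Every row is visited by exactly one program, so the grid can be sized to the hardware (a fixed number of persistent programs) rather than to the number of rows.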

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
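What one loop iteration of the kernel does to a single row can be emulated in plain Python. This is a sketch under stated assumptions: next_power_of_2 is a hypothetical helper intended to match triton.next_power_of_2 for n ≥ 1, and Python lists stand in for Triton blocks. Padded lanes are filled with -inf, mirroring other=-float('inf'), so they contribute exp(-inf) = 0 to the sum and never win the max.

```python
import math


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def softmax_row_padded(row: list[float]) -> list[float]:
    """Emulate one row of softmax_kernel: masked load, stable softmax, masked store."""
    n_cols = len(row)
    block_size = next_power_of_2(n_cols)
    # Masked load: lanes at or beyond n_cols get the `other` value, -inf.
    block = [row[i] if i < n_cols else -math.inf for i in range(block_size)]
    # Subtracting the row maximum keeps every exponent <= 0, so exp() cannot overflow.
    row_max = max(block)
    numerator = [math.exp(v - row_max) for v in block]  # padded lanes become 0.0
    denominator = sum(numerator)
    out = [v / denominator for v in numerator]
    # Masked store: only the first n_cols lanes are written back.
    return out[:n_cols]


probs = softmax_row_padded([1000.0, 1001.0, 1002.0])  # 3 columns, padded to a block of 4
print(len(probs), round(sum(probs), 6))  # 3 1.0
```

Without the max subtraction, the first step would be math.exp(1000.0), which raises OverflowError in Python; on the GPU it would instead produce inf and then NaN after the division.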

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages: use more stages on devices that report more shared memory.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # Pre-compile the kernel to get its register and shared-memory usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular (general-purpose) registers. On CDNA architectures this is half of
        # all available registers. This is not always the case: in most cases, all registers can be used as regular registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents the maximum number of resident threads per multiprocessor.
        # Dividing this number by WARP_SIZE gives the maximum number of waves that can
        # execute on a CU (multiprocessor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
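The occupancy logic in softmax() is integer arithmetic on device properties and kernel metadata. The sketch below reproduces it with plain integers so the heuristic can be inspected without a GPU. The function name estimate_num_programs and all example numbers are hypothetical and not taken from any particular device; the zero check on shared memory is an addition of this sketch.

```python
def estimate_num_programs(*, num_sm: int, num_regs: int, warp_size: int,
                          smem_per_sm: int, n_regs: int, smem_per_program: int,
                          num_warps: int, n_rows: int, hip: bool = False,
                          cdna: bool = False, max_threads_per_sm: int = 0) -> int:
    """Mirror the occupancy arithmetic of softmax() using plain integers."""
    if hip:
        # On CDNA, NUM_REGS counts only regular VGPRs, half of the total pool.
        num_gprs = num_regs * 2 if cdna else num_regs
        max_num_waves = max_threads_per_sm // warp_size
        occupancy = min(num_gprs // warp_size // n_regs, max_num_waves) // num_warps
    else:
        # Programs per SM limited by the register file.
        occupancy = num_regs // (n_regs * warp_size * num_warps)
    # Programs per SM limited by shared memory. softmax() above divides
    # unconditionally; this sketch skips the cap when no shared memory is used.
    if smem_per_program > 0:
        occupancy = min(occupancy, smem_per_sm // smem_per_program)
    # Enough persistent programs to fill every SM, but never more than there are rows.
    return min(num_sm * occupancy, n_rows)


# Hypothetical numbers, not measured on any particular GPU.
print(estimate_num_programs(num_sm=108, num_regs=65536, warp_size=32,
                            smem_per_sm=101376, n_regs=32, smem_per_program=4096,
                            num_warps=8, n_rows=1823))  # 864
```

Here the register file allows 65536 // (32 * 32 * 8) = 8 programs per SM, shared memory would allow 24, so registers are the binding limit and 108 * 8 = 864 programs are launched.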

Unit Test

We test our kernel on a matrix with an irregular number of rows and columns (781 columns is not a power of two), which lets us verify that the padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match to within the default tolerances of torch.allclose.
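torch.allclose does not test for exact equality; it checks elementwise that |input - other| ≤ atol + rtol × |other|, with defaults rtol=1e-05 and atol=1e-08. A hypothetical pure-Python version of that criterion for flat lists (not torch's implementation, and without broadcasting) makes the tolerance concrete:

```python
def allclose(a: list[float], b: list[float], rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Elementwise |a - b| <= atol + rtol * |b|; NaNs never compare close."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


print(allclose([0.5, 0.25], [0.5000001, 0.25]))  # True: a 1e-7 difference is within tolerance
print(allclose([0.5], [0.51]))                   # False
```

Small differences are expected here because tl.exp is approximate and the reduction order differs from PyTorch's.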

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
        line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `line_arg`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'naive_softmax':
        ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Figure: softmax-performance plot of throughput (GB/s) versus N for Triton, Torch and Naive Softmax]
softmax-performance:
          N  Triton (GB/s)  Torch (GB/s)  Naive Softmax (GB/s)
0     256.0     514.264267    710.308844            205.684566
1     384.0     709.130739    820.099233            262.047238
2     512.0     831.113104    926.637333            302.423697
3     640.0     845.012367    910.118495            331.453903
4     768.0     920.260997    984.541553            351.905361
5     896.0     968.800342   1038.174820            355.361785
6    1024.0    1023.589910   1075.672814            354.399517
7    1152.0    1027.876081   1068.003854            348.679859
8    1280.0    1082.334979   1109.563610            348.234217
9    1408.0    1124.233870   1136.594676            343.090526
10   1536.0    1146.434784   1165.091116            333.637247
11   1664.0    1180.542121   1178.782176            329.831516
12   1792.0    1202.236290   1200.369240            326.294104
13   1920.0    1234.823545   1226.698613            326.392908
14   2048.0    1250.537876   1246.703012            325.807920
15   2176.0    1185.344056    964.733820            326.001449
16   2304.0    1197.742271   1004.568615            326.107284
17   2432.0    1220.500784   1035.553574            327.270171
18   2560.0    1245.425385   1071.372786            328.644020
19   2688.0    1268.226957   1097.909776            329.540370
20   2816.0    1276.486904   1127.178672            329.376307
21   2944.0    1293.509830   1144.780139            331.862932
22   3072.0    1308.364934   1171.347845            334.144493
23   3200.0    1323.065987   1170.899844            334.975634
24   3328.0    1321.977898   1199.030711            336.250247
25   3456.0    1344.006019   1220.282338            337.369985
26   3584.0    1348.584291   1244.280350            338.318047
27   3712.0    1350.214641   1265.563111            340.154016
28   3840.0    1363.937470   1281.479263            341.097767
29   3968.0    1369.911621   1296.414309            341.116968
30   4096.0    1372.520927   1320.005111            339.147794
31   4224.0    1337.937950   1275.396837            342.882778
32   4352.0    1344.612825   1298.449218            344.933851
33   4480.0    1356.305577   1314.354152            346.108963
34   4608.0    1368.563420   1335.764560            346.683385
35   4736.0    1363.832377   1344.591131            347.952697
36   4864.0    1381.945394   1356.335887            349.543536
37   4992.0    1377.411409   1368.175999            350.511724
38   5120.0    1386.624547   1384.651024            351.021800
39   5248.0    1382.458606   1357.826632            352.076863
40   5376.0    1383.900399   1368.195770            351.716787
41   5504.0    1390.987944   1379.897300            353.628359
42   5632.0    1398.559211   1399.547777            353.706218
43   5760.0    1405.556787   1405.382160            354.606333
44   5888.0    1392.290556   1412.974571            354.832389
45   6016.0    1405.294169   1423.245445            356.396805
46   6144.0    1417.238568   1432.749175            356.889724
47   6272.0    1416.502069   1398.574706            357.848966
48   6400.0    1417.938216   1411.810886            358.821483
49   6528.0    1424.625112   1424.500724            359.517905
50   6656.0    1421.624639   1436.761401            359.348221
51   6784.0    1423.619997   1443.753304            360.206259
52   6912.0    1433.749533   1447.611932            360.675091
53   7040.0    1423.930225   1443.871189            361.002863
54   7168.0    1431.701915   1460.786881            361.755280
55   7296.0    1433.029000   1086.976665            362.956860
56   7424.0    1436.842862   1096.046149            362.599411
57   7552.0    1436.033943   1111.040374            363.573085
58   7680.0    1433.409213   1119.168489            363.586455
59   7808.0    1429.387914   1134.193956            364.789697
60   7936.0    1441.907121   1144.670112            364.986581
61   8064.0    1439.699457   1151.941119            365.048978
62   8192.0    1429.125916   1148.652331            363.692001
63   8320.0    1389.156504   1114.892884            361.292870
64   8448.0    1385.946510   1122.729839            362.158007
65   8576.0    1391.605442   1126.640472            363.071994
66   8704.0    1388.162918   1129.806384            364.084440
67   8832.0    1394.561608   1127.800743            364.693888
68   8960.0    1385.069692   1136.568658            365.749829
69   9088.0    1401.081398   1134.905847            366.804488
70   9216.0    1402.843751   1138.959660            367.303125
71   9344.0    1393.951105   1418.160746            367.653672
72   9472.0    1406.857297   1433.454705            368.515502
73   9600.0    1393.235044   1429.553272            368.760685
74   9728.0    1401.458065   1441.288254            369.683333
75   9856.0    1396.620020   1441.643412            369.945633
76   9984.0    1391.710160   1447.342120            370.201790
77  10112.0    1408.177239   1454.496136            371.450252
78  10240.0    1408.013746   1463.856631            371.765898
79  10368.0    1412.831752   1464.424151            370.076976
80  10496.0    1410.730869   1464.739874            370.413013
81  10624.0    1407.995487   1467.148953            370.872699
82  10752.0    1401.726694   1473.363476            371.501446
83  10880.0    1394.905228   1480.363461            371.858680
84  11008.0    1416.782862   1477.417634            372.979405
85  11136.0    1424.858054   1486.373701            372.988118
86  11264.0    1413.743334   1485.504816            373.405603
87  11392.0    1424.490259   1488.262365            374.089737
88  11520.0    1415.587495   1494.503457            373.667128
89  11648.0    1417.579578   1498.980549            374.409829
90  11776.0    1433.665511   1500.585282            374.637699
91  11904.0    1432.593872   1509.243634            375.294065
92  12032.0    1429.695313   1508.840525            376.420749
93  12160.0    1422.011454   1517.582988            375.855074
94  12288.0    1431.659257   1418.086413            375.914640
95  12416.0    1433.010428   1393.989466            374.006705
96  12544.0    1443.285842   1393.493167            375.195879
97  12672.0    1436.689774   1392.959954            375.022730
In the above plot, we can see that:
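  • Triton is roughly 4x faster than naive_softmax (for example, about 1429 GB/s versus 364 GB/s at N = 8192), in line with the ~4x reduction in memory traffic estimated above. This confirms that the unfused implementation pays for every intermediate result it writes to and reads back from DRAM.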

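  • Triton is competitive with torch.softmax, in addition to being easier to read, understand and maintain. In these measurements it is slower for narrow rows (N ≤ 1536) and for N from 9344 to 12160, but clearly faster for N from 2176 to 4096 and from 7296 to 9216, where Torch's throughput drops. Note, however, that the PyTorch softmax operation is more general and will work on tensors of any shape.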
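The GB/s figures come from the gbps lambda in benchmark(), which counts one read and one write of x regardless of provider. A minimal restatement as a named function (effective_gbps is a hypothetical name) shows the units:

```python
def effective_gbps(numel: int, element_size: int, ms: float) -> float:
    """Effective bandwidth as in benchmark(): one read plus one write of x, in GB/s."""
    bytes_moved = 2 * numel * element_size
    return bytes_moved * 1e-9 / (ms * 1e-3)


M, N = 4096, 1024
print(effective_gbps(M * N, 4, 1.0))  # float32 input at 1 ms per call: about 33.55 GB/s
```

Because this formula counts only the \(2MN\) elements an ideal fused kernel moves, the Naive Softmax column reports effective rather than physical bandwidth: by the earlier estimate, naive_softmax actually moves about \((8MN + 4M) / (2MN) \approx 4\) times as many bytes. At N = 8192, its 363.7 GB/s effective corresponds to roughly 1450 GB/s of actual traffic, close to Triton's 1429 GB/s, which suggests both are limited by the same DRAM bandwidth.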

