Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that only reads X once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(MN\) elements each, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically, but in practice it is still far from ideal.
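
As a quick sanity check on this arithmetic, the short snippet below (illustrative only, with an arbitrarily chosen shape) plugs concrete numbers into the formulas above:

# Arbitrary shape, chosen only to illustrate the memory-traffic estimate above.
M, N = 4096, 8192
naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_traffic = 2 * M * N                                  # fused kernel: read X once, write Y once
print(f"naive : {naive_traffic:,} elements")
print(f"fused : {fused_traffic:,} elements")
print(f"bound : {naive_traffic / fused_traffic:.2f}x")     # approaches 4x as N grows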

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes them and writes back the results to the output Y.
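
To make the "strided by the number of programs" scheme concrete, here is a tiny pure-Python sketch (with made-up sizes, not the values used at launch time) of which rows each program instance ends up processing; it mirrors the tl.range loop in the kernel below.

# Illustration only: hypothetical n_rows / num_programs.
n_rows, num_programs = 10, 4
for pid in range(num_programs):
    # Program `pid` starts at row `pid` and advances by `num_programs` rows each iteration.
    print(f"program {pid} -> rows {list(range(pid, n_rows, num_programs))}")
# program 0 -> rows [0, 4, 8]
# program 1 -> rows [1, 5, 9]
# program 2 -> rows [2, 6]
# program 3 -> rows [3, 7]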

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the next power of two greater than or equal to n_cols, so we can
        # fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
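
Why fill the masked-off lanes with -inf? Those lanes then contribute nothing to either reduction: they do not change the row maximum, and exp(-inf) = 0, so they add nothing to the denominator. The small PyTorch check below (illustrative only, not part of the kernel) verifies this property for one padded row.

# Illustration: padding a row out to a power of two with -inf leaves the softmax of the valid part unchanged.
demo_row = torch.randn(781, device=DEVICE)
demo_padded = torch.full((1024,), float('-inf'), device=DEVICE)  # 1024 = next power of two >= 781
demo_padded[:781] = demo_row
demo_softmax = torch.exp(demo_padded - demo_padded.max()) / torch.exp(demo_padded - demo_padded.max()).sum()
assert torch.allclose(torch.softmax(demo_row, dim=0), demo_softmax[:781])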

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor)  in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
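    # Shared memory is a second limiter: each program needs `size_smem` bytes of the multiprocessor's shared memory.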
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

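    # No need to launch more programs than there are rows to process.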
    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results are identical.

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows. We will then compare its performance against torch.softmax; the naive_softmax defined above is omitted from the plot, since its memory traffic is roughly 4x higher.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
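    # Effective memory bandwidth: softmax reads MN elements and writes MN elements, hence the factor of 2;
    # `ms` is milliseconds, so the 1e-9 / 1e-3 factors convert bytes per millisecond into GB/s.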
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
[Plot: softmax-performance, GB/s as a function of N, for Triton and Torch]
softmax-performance:
          N       Triton        Torch
0     256.0   466.364565   697.334601
1     384.0   650.818893   817.308710
2     512.0   800.797842   911.754582
3     640.0   802.169752   959.639932
4     768.0   882.094429  1017.303340
5     896.0   937.070856  1074.489355
6    1024.0  1009.039332  1107.706231
7    1152.0  1104.272729  1034.414615
8    1280.0  1142.811903  1078.972327
9    1408.0  1166.462589  1113.068593
10   1536.0  1184.357741  1139.453928
11   1664.0  1209.753880  1159.962015
12   1792.0  1238.147763  1195.260716
13   1920.0  1251.237712  1197.260076
14   2048.0  1270.146597  1227.116459
15   2176.0  1243.839878   958.276887
16   2304.0  1240.547606  1000.883769
17   2432.0  1266.856304  1032.416770
18   2560.0  1285.268236  1063.677866
19   2688.0  1289.105210  1096.634250
20   2816.0  1292.994410  1120.523157
21   2944.0  1307.737326  1146.880018
22   3072.0  1325.606229  1167.525615
23   3200.0  1325.029087  1175.538624
24   3328.0  1342.840232  1202.726888
25   3456.0  1352.594982  1224.965486
26   3584.0  1346.724265  1245.554565
27   3712.0  1366.991436  1264.789115
28   3840.0  1369.289856  1285.408614
29   3968.0  1372.231831  1295.605875
30   4096.0  1372.481667  1317.170795
31   4224.0  1334.927257  1290.115436
32   4352.0  1335.387621  1315.784686
33   4480.0  1349.596728  1333.533314
34   4608.0  1359.269568  1351.233157
35   4736.0  1358.152306  1366.726616
36   4864.0  1378.601053  1382.478386
37   4992.0  1368.874500  1392.486813
38   5120.0  1374.322355  1406.676849
39   5248.0  1376.228596  1357.665507
40   5376.0  1379.030189  1384.667838
41   5504.0  1381.412849  1393.023666
42   5632.0  1384.807598  1409.037237
43   5760.0  1390.220001  1422.183301
44   5888.0  1389.242411  1434.552412
45   6016.0  1396.127682  1433.920012
46   6144.0  1413.288506  1447.100857
47   6272.0  1411.016265  1407.094404
48   6400.0  1415.103857  1421.164796
49   6528.0  1413.242316  1427.466102
50   6656.0  1422.046727  1433.858745
51   6784.0  1414.958636  1455.001110
52   6912.0  1429.146988  1458.376153
53   7040.0  1418.504192  1460.792465
54   7168.0  1428.515456  1467.977055
55   7296.0  1430.681425  1085.766664
56   7424.0  1427.078482  1101.834555
57   7552.0  1426.245465  1113.776095
58   7680.0  1437.368097  1127.952688
59   7808.0  1432.683862  1136.157330
60   7936.0  1436.643328  1145.693467
61   8064.0  1439.171712  1151.663838
62   8192.0  1438.339987  1156.014178
63   8320.0  1397.822447  1113.864384
64   8448.0  1382.990073  1123.100066
65   8576.0  1398.372849  1125.691611
66   8704.0  1395.704378  1126.867959
67   8832.0  1386.248120  1128.248706
68   8960.0  1401.473910  1135.203613
69   9088.0  1413.768491  1131.505463
70   9216.0  1407.978843  1127.626824
71   9344.0  1411.909147  1423.134257
72   9472.0  1406.283856  1430.488357
73   9600.0  1401.760796  1432.104370
74   9728.0  1407.178381  1440.650606
75   9856.0  1421.384574  1441.313389
76   9984.0  1402.296627  1448.714441
77  10112.0  1412.858505  1451.885339
78  10240.0  1419.102508  1462.548677
79  10368.0  1416.232210  1463.111199
80  10496.0  1419.267579  1463.917231
81  10624.0  1417.292944  1464.452488
82  10752.0  1410.445695  1473.632699
83  10880.0  1408.355557  1477.357192
84  11008.0  1426.457127  1478.589244
85  11136.0  1427.076958  1481.097270
86  11264.0  1432.259854  1488.126310
87  11392.0  1419.375491  1489.392703
88  11520.0  1428.307676  1497.033987
89  11648.0  1429.634207  1500.164526
90  11776.0  1436.568935  1503.278345
91  11904.0  1444.700931  1509.006937
92  12032.0  1429.864182  1508.444236
93  12160.0  1427.230047  1516.440878
94  12288.0  1440.351650  1422.074616
95  12416.0  1457.126376  1397.884923
96  12544.0  1445.445651  1394.588480
97  12672.0  1457.101641  1393.874664
In the above plot, we can see that:

  • Triton is competitive with torch.softmax, and noticeably faster for the column counts where the Torch kernel's throughput dips, in addition to being easier to read, understand and maintain. Note however that the PyTorch softmax operation is more general and will work on tensors of any shape.
