Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch’s native op for a particular class of matrices: those whose rows can fit in the GPU’s SRAM.

In doing so, you will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won’t get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch

import triton
import triton.language as tl
from triton.runtime import driver

DEVICE = triton.runtime.driver.active.get_active_torch_device()


def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"


def is_cdna():
    return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                   'gfx90a', 'gfx908')


def naive_softmax(x):
    """Compute row-wise softmax of X using native pytorch

    We subtract the maximum element in order to avoid overflows. Softmax is invariant to
    this shift.
    """
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read MN + M elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read MN + M elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for \(x \in \mathbb{R}^{M \times N}\) requires reading \(5MN + 2M\) elements from DRAM and writing back \(3MN + 2M\) elements. This is obviously wasteful; we’d prefer to have a custom “fused” kernel that reads X only once and does all the necessary computations on-chip. Doing so would require reading and writing back only \(2MN\) elements in total, so we could expect a theoretical speed-up of ~4x (i.e., \((8MN + 4M) / 2MN\)). The torch.jit.script flag aims to perform this kind of “kernel fusion” automatically but, as we will see later, it is still far from ideal.
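
As a quick sanity check of that arithmetic, here is the ratio evaluated for one (hypothetical) choice of dimensions:

M, N = 4096, 1024  # hypothetical sizes, for illustration only
naive_io = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_io = 2 * M * N  # fused kernel: read X once, write Y once
print(naive_io / fused_io)  # ~4.0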

Compute Kernel

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X, strided by the number of programs, normalizes each of them, and writes the results back to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

@triton.jit
def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                   num_stages: tl.constexpr):
    # starting row of the program
    row_start = tl.program_id(0)
    row_step = tl.num_programs(0)
    for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
        mask = col_offsets < n_cols
        row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=mask)
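
To see why -float('inf') is the right value for the masked-out lanes, recall that exp(-inf) == 0, so the padding contributes nothing to either the row maximum or the sum. A minimal plain-PyTorch sketch of the same padding trick:

n_cols, BLOCK_SIZE = 781, 1024  # BLOCK_SIZE is the next power of two >= n_cols
row = torch.randn(n_cols)
padded = torch.full((BLOCK_SIZE, ), -float('inf'))
padded[:n_cols] = row
# exp(-inf) == 0, so the padded lanes change neither the max nor the sum
assert torch.allclose(torch.softmax(padded, dim=0)[:n_cols], torch.softmax(row, dim=0))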

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

properties = driver.active.utils.get_device_properties(DEVICE.index)
NUM_SM = properties["multiprocessor_count"]
NUM_REGS = properties["max_num_regs"]
SIZE_SMEM = properties["max_shared_mem"]
WARP_SIZE = properties["warpSize"]
target = triton.runtime.driver.active.get_current_target()
kernels = {}


def softmax(x):
    n_rows, n_cols = x.shape

    # The block size of each loop iteration is the smallest power of two greater than or equal to the number of columns in `x`
    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 8

    # Number of software pipelining stages.
    num_stages = 4 if SIZE_SMEM > 200000 else 2

    # Allocate output
    y = torch.empty_like(x)

    # pre-compile kernel to get register usage and compute thread occupancy.
    kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=num_stages, num_warps=num_warps, grid=(1, ))
    kernel._init_handles()
    n_regs = kernel.n_regs
    size_smem = kernel.metadata.shared
    if is_hip():
        # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
        # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
        # ISA SECTION (3.6.4 for CDNA3)
        # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
        # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
        # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
        # not required to be equal numbers of both types.
        NUM_GPRS = NUM_REGS
        if is_cdna():
            NUM_GPRS = NUM_REGS * 2

        # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
        # When we divide this number with WARP_SIZE we get maximum number of waves that can
        # execute on a CU (multi-processor) in parallel.
        MAX_NUM_THREADS = properties["max_threads_per_sm"]
        max_num_waves = MAX_NUM_THREADS // WARP_SIZE
        occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
    else:
        occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
    occupancy = min(occupancy, SIZE_SMEM // size_smem)
    num_programs = NUM_SM * occupancy

    num_programs = min(num_programs, n_rows)

    # Create a number of persistent programs.
    kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
    return y
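
To make the occupancy arithmetic concrete, here is the CUDA branch evaluated with illustrative, made-up numbers; in practice the inputs come from `properties` and the compiled `kernel` above:

num_sm, num_regs, smem_per_sm, warp_size = 108, 65536, 164 * 1024, 32  # hypothetical device
n_regs_used, smem_used, n_warps, n_rows_ex = 32, 16 * 1024, 8, 1823    # hypothetical kernel
occ = num_regs // (n_regs_used * warp_size * n_warps)  # 65536 // 8192 = 8 programs per SM by registers
occ = min(occ, smem_per_sm // smem_used)               # shared memory would allow 10, so registers bind
print(min(num_sm * occ, n_rows_ex))                    # 108 * 8 = 864 persistent programs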

Unit Test

We make sure that we test our kernel on a matrix with an irregular number of rows and columns. This will allow us to verify that our padding mechanism works.

torch.manual_seed(0)
x = torch.randn(1823, 781, device=DEVICE)
y_triton = softmax(x)
y_torch = torch.softmax(x, axis=1)
assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

As expected, the results match within floating-point tolerance.

Benchmark

Here we will benchmark our operation as a function of the number of columns in the input matrix – assuming 4096 rows – and compare its performance against torch.softmax.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == 'torch':
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: softmax(x))
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)  # read + write traffic in GB/s
    return gbps(ms)


benchmark.run(show_plots=True, print_data=True)
softmax-performance:
          N       Triton        Torch
0     256.0   470.978332   701.687163
1     384.0   657.046463   823.769029
2     512.0   807.623624   914.109169
3     640.0   882.785457   947.893530
4     768.0   966.310849  1020.015563
5     896.0  1016.609970  1062.339363
6    1024.0  1066.900700  1124.095905
7    1152.0  1108.738464  1030.960172
8    1280.0  1138.822542  1069.541463
9    1408.0  1164.249191  1105.239250
10   1536.0  1179.214278  1133.391458
11   1664.0  1206.280581  1172.741617
12   1792.0  1229.846580  1191.847293
13   1920.0  1250.989273  1193.521627
14   2048.0  1279.660788  1224.995187
15   2176.0  1246.470077   963.575395
16   2304.0  1250.444109  1002.801480
17   2432.0  1274.255923  1042.264524
18   2560.0  1287.134793  1070.865953
19   2688.0  1296.740664  1101.098457
20   2816.0  1299.414948  1121.499955
21   2944.0  1323.460189  1150.705621
22   3072.0  1328.255127  1171.606015
23   3200.0  1336.806082  1179.513950
24   3328.0  1346.874872  1203.505746
25   3456.0  1352.139083  1221.506282
26   3584.0  1348.640976  1247.681154
27   3712.0  1362.861991  1264.200202
28   3840.0  1376.467408  1283.433035
29   3968.0  1371.757957  1301.804441
30   4096.0  1385.874169  1316.756551
31   4224.0  1337.132338  1295.909618
32   4352.0  1337.910554  1319.074892
33   4480.0  1352.801533  1333.709298
34   4608.0  1360.171779  1356.057065
35   4736.0  1361.794351  1367.855576
36   4864.0  1376.078475  1384.464041
37   4992.0  1376.901102  1396.357132
38   5120.0  1373.494378  1405.912769
39   5248.0  1378.435147  1367.206376
40   5376.0  1378.126308  1382.576327
41   5504.0  1387.285763  1392.836125
42   5632.0  1388.112728  1409.269053
43   5760.0  1396.261620  1423.088800
44   5888.0  1395.452816  1425.882295
45   6016.0  1396.320079  1433.370366
46   6144.0  1401.437664  1438.605996
47   6272.0  1414.430063  1410.725117
48   6400.0  1416.906908  1419.780491
49   6528.0  1411.271244  1432.799779
50   6656.0  1413.156172  1443.615925
51   6784.0  1413.261177  1451.404443
52   6912.0  1424.422773  1452.297443
53   7040.0  1421.637812  1454.983874
54   7168.0  1417.603576  1461.404738
55   7296.0  1429.798839  1084.869735
56   7424.0  1431.221102  1100.295661
57   7552.0  1425.874625  1113.697017
58   7680.0  1435.798310  1126.375496
59   7808.0  1425.099180  1135.422377
60   7936.0  1435.708060  1144.793951
61   8064.0  1435.586981  1153.504373
62   8192.0  1436.923848  1156.269126
63   8320.0  1379.171171  1114.501104
64   8448.0  1379.030784  1123.865804
65   8576.0  1390.105003  1124.515235
66   8704.0  1380.243375  1129.949517
67   8832.0  1381.072715  1130.295460
68   8960.0  1394.414844  1136.036844
69   9088.0  1410.892251  1130.223800
70   9216.0  1397.362322  1128.744948
71   9344.0  1400.354490  1421.210205
72   9472.0  1393.530763  1433.360631
73   9600.0  1397.761406  1431.225555
74   9728.0  1405.523492  1437.142920
75   9856.0  1396.429582  1440.369504
76   9984.0  1401.238070  1444.727769
77  10112.0  1401.803841  1452.943359
78  10240.0  1423.042407  1466.950202
79  10368.0  1409.037797  1462.317338
80  10496.0  1415.249472  1463.784458
81  10624.0  1400.139297  1464.812037
82  10752.0  1404.510763  1472.090890
83  10880.0  1404.918279  1477.783502
84  11008.0  1411.919631  1479.354753
85  11136.0  1424.070611  1480.523601
86  11264.0  1421.457930  1487.843155
87  11392.0  1420.358185  1485.709549
88  11520.0  1421.296138  1496.295659
89  11648.0  1432.400225  1499.858134
90  11776.0  1435.562742  1501.171293
91  11904.0  1437.085261  1509.589623
92  12032.0  1424.839440  1510.485480
93  12160.0  1403.574468  1513.071395
94  12288.0  1440.074017  1424.375949
95  12416.0  1441.300454  1396.538354
96  12544.0  1453.757131  1389.912639
97  12672.0  1438.051017  1393.591433
In the above plot, we can see that:

  • Triton’s throughput is comparable to torch.softmax across the sweep, and higher over several ranges of N – in addition to the kernel being easier to read, understand and maintain.

  • The PyTorch softmax operation is more general, however, and will work on tensors of any shape (see the sketch below).
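
For reference, one (hypothetical) way to reuse our 2D kernel on a higher-rank tensor is to flatten the leading dimensions first; this sketch assumes a contiguous input and softmax over the last dimension:

t = torch.randn(8, 16, 32, device=DEVICE)
y = softmax(t.reshape(-1, t.shape[-1])).reshape(t.shape)
assert torch.allclose(y, torch.softmax(t, dim=-1))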
