.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/02-fused-softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:

Fused Softmax
=============

In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
the GPU's SRAM.

In doing so, you will learn about:

* The benefits of kernel fusion for bandwidth-bound operations.

* Reduction operators in Triton.

.. GENERATED FROM PYTHON SOURCE LINES 18-23

Motivations
-----------

Custom GPU kernels for elementwise additions are educationally valuable but won't get you
very far in practice. Let us consider instead the case of a simple (numerically stabilized)
softmax operation:

.. GENERATED FROM PYTHON SOURCE LINES 23-62

.. code-block:: Python

    import torch

    import triton
    import triton.language as tl
    from triton.runtime import driver

    DEVICE = triton.runtime.driver.active.get_active_torch_device()


    def is_hip():
        return triton.runtime.driver.active.get_current_target().backend == "hip"


    def is_cdna():
        return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                       'gfx90a', 'gfx908')


    def naive_softmax(x):
        """Compute row-wise softmax of X using native PyTorch.

        We subtract the maximum element in order to avoid overflows. Softmax is invariant to
        this shift.
        """
        # read MN elements ; write M elements
        x_max = x.max(dim=1)[0]
        # read MN + M elements ; write MN elements
        z = x - x_max[:, None]
        # read MN elements ; write MN elements
        numerator = torch.exp(z)
        # read MN elements ; write M elements
        denominator = numerator.sum(dim=1)
        # read MN + M elements ; write MN elements
        ret = numerator / denominator[:, None]
        # in total: read 5MN + 2M elements ; write 3MN + 2M elements
        return ret

.. GENERATED FROM PYTHON SOURCE LINES 63-71

When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only :math:`MN` elements each way, so we could
expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
The :code:`torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically
but, in practice, it is still far from ideal.
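To make the traffic accounting concrete, here is a small back-of-the-envelope sketch. It is purely
illustrative and not part of the generated tutorial code (the helper name :code:`expected_speedup` is made up);
it simply plugs a few sizes into the element counts above:

.. code-block:: Python

    def expected_speedup(M, N):
        # Elements moved by naive_softmax, per the line-by-line accounting above.
        naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)
        # Elements moved by a fully fused kernel: read X once, write Y once.
        fused_traffic = 2 * M * N
        return naive_traffic / fused_traffic


    # For any reasonably wide matrix this ratio is ~4x, e.g.:
    print(expected_speedup(4096, 1024))  # 4.001953125
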
.. GENERATED FROM PYTHON SOURCE LINES 73-82

Compute Kernel
--------------

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X,
strided by the number of programs, normalizes them and writes back the results to the output Y.

Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally "pad" each row and guard the
memory operations properly if we want to handle any possible input shapes:

.. GENERATED FROM PYTHON SOURCE LINES 82-112

.. code-block:: Python

    @triton.jit
    def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
                       BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
        # starting row of the program
        row_start = tl.program_id(0)
        row_step = tl.num_programs(0)
        for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
            # The stride represents how much we need to increase the pointer to advance 1 row
            row_start_ptr = input_ptr + row_idx * input_row_stride
            # The block size is the next power of two greater than or equal to n_cols, so we can fit each
            # row in a single block
            col_offsets = tl.arange(0, BLOCK_SIZE)
            input_ptrs = row_start_ptr + col_offsets
            # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
            mask = col_offsets < n_cols
            row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
            # Subtract maximum for numerical stability
            row_minus_max = row - tl.max(row, axis=0)
            # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)
            softmax_output = numerator / denominator
            # Write back output to DRAM
            output_row_start_ptr = output_ptr + row_idx * output_row_stride
            output_ptrs = output_row_start_ptr + col_offsets
            tl.store(output_ptrs, softmax_output, mask=mask)
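Because the kernel uses a grid-stride loop over rows, it works with any 1D launch grid: with one program
per row each program executes a single loop iteration, and with fewer programs each one processes several
rows. As a purely illustrative sketch (the :code:`softmax_simple` wrapper below is not part of this tutorial
and assumes the kernel defined above), the simplest possible launch would look like this:

.. code-block:: Python

    def softmax_simple(x):
        # Minimal launcher sketch: one program instance per row, no occupancy tuning.
        n_rows, n_cols = x.shape
        # Pad each row to the next power of two so it fits in a single block.
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        y = torch.empty_like(x)
        # With a grid of n_rows programs, every program handles exactly one row.
        softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                                   BLOCK_SIZE=BLOCK_SIZE, num_stages=2, num_warps=8)
        return y

The launcher used in this tutorial instead keeps a fixed number of persistent programs resident on the GPU
and lets each one loop over several rows, which is why it needs the occupancy calculation below.
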
.. GENERATED FROM PYTHON SOURCE LINES 113-114

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

.. GENERATED FROM PYTHON SOURCE LINES 114-178

.. code-block:: Python

    properties = driver.active.utils.get_device_properties(DEVICE.index)
    NUM_SM = properties["multiprocessor_count"]
    NUM_REGS = properties["max_num_regs"]
    SIZE_SMEM = properties["max_shared_mem"]
    WARP_SIZE = properties["warpSize"]
    target = triton.runtime.driver.active.get_current_target()
    kernels = {}


    def softmax(x):
        n_rows, n_cols = x.shape

        # The block size of each loop iteration is the smallest power of two greater than or equal to
        # the number of columns in `x`
        BLOCK_SIZE = triton.next_power_of_2(n_cols)

        # Another trick we can use is to ask the compiler to use more threads per row by
        # increasing the number of warps (`num_warps`) over which each row is distributed.
        # You will see in the next tutorial how to auto-tune this value in a more natural
        # way so you don't have to come up with manual heuristics yourself.
        num_warps = 8

        # Number of software pipelining stages.
        num_stages = 4 if SIZE_SMEM > 200000 else 2

        # Allocate output
        y = torch.empty_like(x)

        # pre-compile kernel to get register usage and compute thread occupancy.
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular-purpose registers. On CDNA architectures this is half
            # of all registers available. However, this is not always the case: in most cases all registers
            # can be used as regular-purpose registers.
            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs
            # are used with matrix VALU instructions, and can also be loaded directly from memory. A wave may
            # have up to 512 total VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the
            # number of each type is flexible - it is not required to be equal numbers of both types.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2

            # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
            # When we divide this number by WARP_SIZE we get the maximum number of waves that can
            # execute on a CU (multi-processor) in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy

        num_programs = min(num_programs, n_rows)

        # Create a number of persistent programs.
        kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
        return y
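To make the occupancy arithmetic concrete, here is an illustrative calculation. Every figure below is made up
for the sake of the example; none of them comes from a real device or from this kernel's actual compilation:

.. code-block:: Python

    # Hypothetical per-SM resources and per-kernel usage, chosen only to illustrate the formulas above.
    NUM_REGS_EX, WARP_SIZE_EX, SIZE_SMEM_EX = 65536, 32, 166912   # registers, warp size, shared mem (bytes)
    n_regs_ex, size_smem_ex, num_warps_ex = 32, 16384, 8          # per-thread regs, kernel smem, warps

    # Register-limited occupancy: how many copies of the kernel fit in the register file of one SM.
    occupancy_ex = NUM_REGS_EX // (n_regs_ex * WARP_SIZE_EX * num_warps_ex)   # -> 8
    # Shared memory can only lower this bound further.
    occupancy_ex = min(occupancy_ex, SIZE_SMEM_EX // size_smem_ex)            # min(8, 10) -> 8
    # With, say, 108 SMs the persistent grid would then hold 108 * 8 = 864 programs (capped by n_rows).
    print(occupancy_ex)
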
.. GENERATED FROM PYTHON SOURCE LINES 179-181

Unit Test
---------

.. GENERATED FROM PYTHON SOURCE LINES 183-185

We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.

.. GENERATED FROM PYTHON SOURCE LINES 185-192

.. code-block:: Python

    torch.manual_seed(0)
    x = torch.randn(1823, 781, device=DEVICE)
    y_triton = softmax(x)
    y_torch = torch.softmax(x, axis=1)
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

.. GENERATED FROM PYTHON SOURCE LINES 193-194

As expected, the results match to within floating-point tolerance.

.. GENERATED FROM PYTHON SOURCE LINES 196-201

Benchmark
---------

Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against :code:`torch.softmax`.

.. GENERATED FROM PYTHON SOURCE LINES 201-232

.. code-block:: Python

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],  # argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
            line_arg='provider',  # argument name whose value corresponds to a different line in the plot
            line_vals=['triton', 'torch'],  # possible values for `line_arg`
            line_names=[
                "Triton",
                "Torch",
            ],  # label name for the lines
            styles=[('blue', '-'), ('green', '-')],  # line styles
            ylabel="GB/s",  # label name for the y-axis
            plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
        ))
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
        stream = getattr(torch, DEVICE.type).Stream()
        getattr(torch, DEVICE.type).set_stream(stream)
        if provider == 'torch':
            ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
        if provider == 'triton':
            ms = triton.testing.do_bench(lambda: softmax(x))
        gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms)


    benchmark.run(show_plots=True, print_data=True)

.. image-sg:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :alt: 02 fused softmax
   :srcset: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    softmax-performance:
               N       Triton        Torch
     0     256.0   478.015445   709.034104
     1     384.0   659.758394   830.660548
     2     512.0   809.860714   911.910946
     3     640.0   875.556218   954.277007
     4     768.0   961.990181  1026.081899
     5     896.0  1004.112208  1072.826486
     6    1024.0  1057.216943  1120.788525
     7    1152.0  1106.995946  1041.940159
     8    1280.0  1152.332785  1078.533762
     9    1408.0  1154.276903  1108.177598
    10    1536.0  1184.029604  1146.467872
    11    1664.0  1210.610875  1174.825814
    12    1792.0  1230.814345  1191.935135
    13    1920.0  1244.043228  1200.797614
    14    2048.0  1281.821694  1221.291550
    15    2176.0  1245.634132   961.459222
    16    2304.0  1253.779504  1000.811075
    17    2432.0  1270.216572  1040.155228
    18    2560.0  1292.925432  1067.975502
    19    2688.0  1304.856171  1100.169317
    20    2816.0  1296.832290  1123.732513
    21    2944.0  1315.836231  1147.642176
    22    3072.0  1326.941755  1168.660416
    23    3200.0  1330.567925  1174.079128
    24    3328.0  1339.389677  1202.726888
    25    3456.0  1361.153381  1228.158204
    26    3584.0  1353.508667  1247.211186
    27    3712.0  1367.421115  1263.230298
    28    3840.0  1376.544264  1285.764654
    29    3968.0  1377.578447  1303.521220
    30    4096.0  1385.523915  1321.879623
    31    4224.0  1332.081126  1297.165790
    32    4352.0  1334.016390  1316.176425
    33    4480.0  1351.512254  1336.368835
    34    4608.0  1360.281020  1356.467527
    35    4736.0  1353.900516  1366.203266
    36    4864.0  1369.370409  1379.779217
    37    4992.0  1377.001972  1392.430741
    38    5120.0  1368.030857  1402.700322
    39    5248.0  1381.633797  1359.310807
    40    5376.0  1374.390156  1390.788940
    41    5504.0  1383.480331  1395.273902
    42    5632.0  1390.108706  1412.375057
    43    5760.0  1394.056479  1420.663372
    44    5888.0  1397.393959  1433.609040
    45    6016.0  1395.666986  1439.397306
    46    6144.0  1404.342854  1441.656327
    47    6272.0  1412.749366  1394.506988
    48    6400.0  1420.097642  1412.966454
    49    6528.0  1409.654117  1418.208078
    50    6656.0  1417.666391  1444.620114
    51    6784.0  1413.198556  1443.168879
    52    6912.0  1419.163270  1453.203722
    53    7040.0  1422.690842  1456.711722
    54    7168.0  1418.312448  1461.527287
    55    7296.0  1426.241611  1085.767764
    56    7424.0  1434.706832  1101.362657
    57    7552.0  1423.485265  1114.874788
    58    7680.0  1435.271256  1126.086625
    59    7808.0  1424.027088  1135.961451
    60    7936.0  1439.970620  1146.519787
    61    8064.0  1433.378034  1154.508441
    62    8192.0  1437.552482  1157.084844
    63    8320.0  1381.473156  1112.843309
    64    8448.0  1377.927465  1124.378700
    65    8576.0  1386.827386  1123.408486
    66    8704.0  1384.420014  1127.887012
    67    8832.0  1383.596422  1127.077808
    68    8960.0  1393.050358  1134.272829
    69    9088.0  1409.117324  1130.529820
    70    9216.0  1392.730319  1128.158661
    71    9344.0  1401.549070  1417.530489
    72    9472.0  1394.007226  1432.136094
    73    9600.0  1398.054839  1432.775074
    74    9728.0  1405.499480  1439.619960
    75    9856.0  1398.343471  1436.939450
    76    9984.0  1402.627485  1451.493446
    77   10112.0  1399.710196  1450.999666
    78   10240.0  1425.728783  1464.423188
    79   10368.0  1408.222154  1459.641740
    80   10496.0  1417.455554  1465.236925
    81   10624.0  1399.647765  1465.853797
    82   10752.0  1414.160330  1472.725393
    83   10880.0  1406.950003  1477.499062
    84   11008.0  1413.594622  1475.530862
    85   11136.0  1421.138489  1482.988465
    86   11264.0  1426.172604  1482.980596
    87   11392.0  1421.283123  1486.652031
    88   11520.0  1419.858103  1496.890664
    89   11648.0  1436.458508  1500.809980
    90   11776.0  1434.462395  1504.885634
    91   11904.0  1436.552667  1506.886162
    92   12032.0  1422.918947  1511.602291
    93   12160.0  1406.337871  1513.890887
    94   12288.0  1440.160689  1424.359076
    95   12416.0  1442.886808  1399.404295
    96   12544.0  1453.214654  1391.525906
    97   12672.0  1439.462920  1395.235997
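The GB/s figures above are derived from the measured runtime: the fused softmax reads and writes every
element of the matrix exactly once, hence the factor of two in the :code:`gbps` lambda. As a purely
illustrative sanity check (all numbers below are made up, not taken from the table):

.. code-block:: Python

    # Illustrative only: convert a hypothetical runtime into effective bandwidth,
    # mirroring the `gbps` lambda used in the benchmark above.
    M, N = 4096, 8192
    bytes_moved = 2 * M * N * 4              # one fp32 read plus one fp32 write per element
    ms = 0.19                                # hypothetical measured runtime, in milliseconds
    gbs = bytes_moved * 1e-9 / (ms * 1e-3)   # bytes -> GB, ms -> s
    print(f"{gbs:.0f} GB/s")                 # ~1413 GB/s for these made-up numbers
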
.. GENERATED FROM PYTHON SOURCE LINES 233-237

In the above plot, we can see that:

- Triton's throughput is on par with :code:`torch.softmax` and stays high for column counts where the
  native op's bandwidth drops sharply -- in addition to the kernel being **easier to read, understand and
  maintain**. Note however that the PyTorch :code:`softmax` operation is more general and will work on
  tensors of any shape.
- Fusing the whole row-wise computation into a single kernel eliminates the redundant DRAM traffic of
  :code:`naive_softmax`, which is exactly what the ~4x estimate in the Motivations section accounts for.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 23.433 seconds)


.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 02-fused-softmax.zip <02-fused-softmax.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_