.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/02-fused-softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:


Fused Softmax
=============

In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
the GPU's SRAM.

In doing so, you will learn about:

* The benefits of kernel fusion for bandwidth-bound operations.

* Reduction operators in Triton.

.. GENERATED FROM PYTHON SOURCE LINES 18-23

Motivations
-----------

Custom GPU kernels for elementwise additions are educationally valuable but won't get you
very far in practice. Let us consider instead the case of a simple
(numerically stabilized) softmax operation:

.. GENERATED FROM PYTHON SOURCE LINES 23-62

.. code-block:: Python

    import torch

    import triton
    import triton.language as tl
    from triton.runtime import driver

    DEVICE = triton.runtime.driver.active.get_active_torch_device()


    def is_hip():
        return triton.runtime.driver.active.get_current_target().backend == "hip"


    def is_cdna():
        return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                       'gfx90a', 'gfx908')


    def naive_softmax(x):
        """Compute row-wise softmax of X using native PyTorch

        We subtract the maximum element in order to avoid overflows. Softmax is invariant to
        this shift.
        """
        # read MN elements ; write M elements
        x_max = x.max(dim=1)[0]
        # read MN + M elements ; write MN elements
        z = x - x_max[:, None]
        # read MN elements ; write MN elements
        numerator = torch.exp(z)
        # read MN elements ; write M elements
        denominator = numerator.sum(dim=1)
        # read MN + M elements ; write MN elements
        ret = numerator / denominator[:, None]
        # in total: read 5MN + 2M elements ; write 3MN + 2M elements
        return ret


.. GENERATED FROM PYTHON SOURCE LINES 63-71

When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for
:math:`x \in R^{M \times N}` requires reading :math:`5MN + 2M` elements from DRAM and
writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading and writing back only :math:`MN` elements each way, so we
could expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically
but, as we will see later, it is still far from ideal.

.. GENERATED FROM PYTHON SOURCE LINES 73-82

Compute Kernel
--------------

Our softmax kernel works as follows: each program loads a set of rows of the input
matrix X strided by the number of programs, normalizes them and writes back the results
to the output Y.

Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally "pad" each row and guard the
memory operations properly if we want to handle any possible input shapes:

.. GENERATED FROM PYTHON SOURCE LINES 82-112
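As a quick plain-PyTorch sanity check (a sketch independent of the kernel below), padding
with :math:`-\infty` is safe because padded entries contribute :math:`e^{-\infty} = 0` to
the normalization, so the softmax of the valid entries is unchanged:

.. code-block:: Python

    import torch

    n_cols = 781       # an irregular row length
    BLOCK_SIZE = 1024  # smallest power of two covering n_cols, as triton.next_power_of_2(781) would return

    row = torch.randn(n_cols)

    # Pad the row to BLOCK_SIZE with -inf, mimicking `other=-float('inf')` in tl.load.
    padded = torch.full((BLOCK_SIZE, ), -float('inf'))
    padded[:n_cols] = row

    # exp(-inf) == 0, so the padded tail adds nothing to the denominator.
    assert torch.allclose(torch.softmax(padded, dim=0)[:n_cols], torch.softmax(row, dim=0))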
.. code-block:: Python

    @triton.jit
    def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                       num_stages: tl.constexpr):
        # starting row of the program
        row_start = tl.program_id(0)
        row_step = tl.num_programs(0)
        for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
            # The stride represents how much we need to increase the pointer to advance 1 row
            row_start_ptr = input_ptr + row_idx * input_row_stride
            # The block size is the next power of two greater than n_cols, so we can fit each
            # row in a single block
            col_offsets = tl.arange(0, BLOCK_SIZE)
            input_ptrs = row_start_ptr + col_offsets
            # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
            mask = col_offsets < n_cols
            row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
            # Subtract maximum for numerical stability
            row_minus_max = row - tl.max(row, axis=0)
            # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)
            softmax_output = numerator / denominator
            # Write back output to DRAM
            output_row_start_ptr = output_ptr + row_idx * output_row_stride
            output_ptrs = output_row_start_ptr + col_offsets
            tl.store(output_ptrs, softmax_output, mask=mask)


.. GENERATED FROM PYTHON SOURCE LINES 113-114

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

.. GENERATED FROM PYTHON SOURCE LINES 114-178
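The :code:`tl.range(row_start, n_rows, row_step, ...)` loop in the kernel above implements a
persistent-program schedule: program :code:`p` out of :code:`P` processes rows
:code:`p, p + P, p + 2P, ...`. A minimal plain-Python sketch of that assignment, with
hypothetical sizes:

.. code-block:: Python

    n_rows, num_programs = 10, 4

    # Row indices handled by each program, mirroring tl.range(row_start, n_rows, row_step).
    schedule = {p: list(range(p, n_rows, num_programs)) for p in range(num_programs)}
    print(schedule)  # {0: [0, 4, 8], 1: [1, 5, 9], 2: [2, 6], 3: [3, 7]}

    # Every row is covered exactly once across programs.
    assert sorted(r for rows in schedule.values() for r in rows) == list(range(n_rows))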
.. code-block:: Python

    properties = driver.active.utils.get_device_properties(DEVICE.index)
    NUM_SM = properties["multiprocessor_count"]
    NUM_REGS = properties["max_num_regs"]
    SIZE_SMEM = properties["max_shared_mem"]
    WARP_SIZE = properties["warpSize"]
    target = triton.runtime.driver.active.get_current_target()
    kernels = {}


    def softmax(x):
        n_rows, n_cols = x.shape

        # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
        BLOCK_SIZE = triton.next_power_of_2(n_cols)

        # Another trick we can use is to ask the compiler to use more threads per row by
        # increasing the number of warps (`num_warps`) over which each row is distributed.
        # You will see in the next tutorial how to auto-tune this value in a more natural
        # way so you don't have to come up with manual heuristics yourself.
        num_warps = 8

        # Number of software pipelining stages.
        num_stages = 4 if SIZE_SMEM > 200000 else 2

        # Allocate output
        y = torch.empty_like(x)

        # pre-compile kernel to get register usage and compute thread occupancy.
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
            # not required to be equal numbers of both types.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2

            # MAX_NUM_THREADS represents maximum number of resident threads per multi-processor.
            # When we divide this number by WARP_SIZE we get maximum number of waves that can
            # execute on a CU (multi-processor) in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy

        num_programs = min(num_programs, n_rows)

        # Create a number of persistent programs.
        kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
        return y


.. GENERATED FROM PYTHON SOURCE LINES 179-181

Unit Test
---------

.. GENERATED FROM PYTHON SOURCE LINES 183-185

We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.

.. GENERATED FROM PYTHON SOURCE LINES 185-192

.. code-block:: Python

    torch.manual_seed(0)
    x = torch.randn(1823, 781, device=DEVICE)
    y_triton = softmax(x)
    y_torch = torch.softmax(x, axis=1)
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)


.. GENERATED FROM PYTHON SOURCE LINES 193-194

As expected, the results are identical.

.. GENERATED FROM PYTHON SOURCE LINES 196-201

Benchmark
---------

Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.

.. GENERATED FROM PYTHON SOURCE LINES 201-231
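The benchmark reports effective bandwidth rather than raw time: a fused softmax reads and
writes each element exactly once, so the traffic is :code:`2 * numel * element_size` bytes.
A minimal sketch of the milliseconds-to-GB/s conversion used below, with a hypothetical
timing value:

.. code-block:: Python

    def gbps(ms, numel, element_size=4):
        """Achieved GB/s for a read-once / write-once kernel that moves
        2 * numel * element_size bytes in `ms` milliseconds."""
        return 2 * numel * element_size * 1e-9 / (ms * 1e-3)


    # e.g. a 4096 x 1024 fp32 matrix processed in a hypothetical 0.03 ms:
    print(round(gbps(0.03, 4096 * 1024), 1))  # 1118.5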
.. code-block:: Python

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],  # argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
            line_arg='provider',  # argument name whose value corresponds to a different line in the plot
            line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
            line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
            styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
            ylabel="GB/s",  # label name for the y-axis
            plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
        ))
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
        stream = getattr(torch, DEVICE.type).Stream()
        getattr(torch, DEVICE.type).set_stream(stream)
        if provider == 'torch':
            ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
        if provider == 'triton':
            ms = triton.testing.do_bench(lambda: softmax(x))
        if provider == 'naive_softmax':
            ms = triton.testing.do_bench(lambda: naive_softmax(x))
        gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms)


    benchmark.run(show_plots=True, print_data=True)


.. image-sg:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :alt: 02 fused softmax
   :srcset: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    softmax-performance:
             N  Triton (GB/s)  Torch (GB/s)  Naive Softmax (GB/s)
    0    256.0     513.183941    697.262687            207.984032
    1    384.0     711.448210    829.159043            263.235187
    2    512.0     838.584374    930.272692            304.094435
    3    640.0     845.906240    916.991600            332.258092
    4    768.0     913.270040    986.345032            350.825352
    5    896.0     970.585717   1039.146564            353.103455
    6   1024.0    1027.696753   1065.516763            352.298764
    7   1152.0    1019.929689   1072.515385            351.043610
    8   1280.0    1074.693080   1104.604886            349.251872
    9   1408.0    1116.030771   1139.372004            342.030116
    10  1536.0    1145.241823   1166.839905            333.219318
    11  1664.0    1190.139405   1193.331374            330.554387
    12  1792.0    1210.661158   1200.159008            325.764535
    13  1920.0    1232.484495   1218.812191            324.519878
    14  2048.0    1258.977425   1243.414534            324.752401
    15  2176.0    1186.103880    959.550233            325.533561
    16  2304.0    1198.352963    999.701653            325.936861
    17  2432.0    1222.727766   1037.243763            326.623401
    18  2560.0    1240.270372   1068.273922            328.758558
    19  2688.0    1261.112896   1101.097139            329.503103
    20  2816.0    1277.656512   1127.574586            329.921569
    21  2944.0    1288.785419   1148.256433            331.765886
    22  3072.0    1302.402476   1171.426131            333.287188
    23  3200.0    1319.921005   1168.675938            335.295921
    24  3328.0    1326.980827   1199.971246            336.441737
    25  3456.0    1342.825564   1221.397569            336.990310
    26  3584.0    1345.592798   1246.743509            338.099321
    27  3712.0    1346.858384   1263.160610            340.452342
    28  3840.0    1351.360374   1284.687134            339.763715
    29  3968.0    1364.831247   1298.369834            341.231977
    30  4096.0    1371.413840   1319.629208            339.018278
    31  4224.0    1336.623352   1275.756159            342.774375
    32  4352.0    1347.304535   1301.981009            345.443032
    33  4480.0    1352.109940   1322.025312            345.929635
    34  4608.0    1364.876968   1337.607615            347.363958
    35  4736.0    1364.677158   1342.281738            348.055230
    36  4864.0    1374.598124   1358.948392            349.299819
    37  4992.0    1373.311031   1372.124599            349.786937
    38  5120.0    1384.165506   1387.467650            351.047136
    39  5248.0    1377.250772   1355.135424            351.892116
    40  5376.0    1383.426895   1370.066146            351.841799
    41  5504.0    1390.190224   1387.575359            353.750127
    42  5632.0    1398.010413   1390.238351            353.426903
    43  5760.0    1397.756101   1405.353462            354.911913
    44  5888.0    1396.917036   1410.197174            355.106393
    45  6016.0    1408.274749   1420.783086            356.796847
    46  6144.0    1412.864898   1436.776557            357.107346
    47  6272.0    1414.929475   1391.062628            357.617943
    48  6400.0    1417.662160   1406.530510            358.613588
    49  6528.0    1422.355433   1423.026626            359.102809
    50  6656.0    1421.610406   1424.044824            359.839041
    51  6784.0    1426.852255   1438.963006            360.360716
    52  6912.0    1432.278066   1444.918389            360.702663
    53  7040.0    1424.891871   1457.174004            361.356553
    54  7168.0    1428.839177   1454.157996            362.161959
    55  7296.0    1429.866355   1087.795005            362.607406
    56  7424.0    1437.738422   1097.628661            362.939956
    57  7552.0    1433.479343   1106.988031            363.664273
    58  7680.0    1432.791942   1121.539657            363.523182
    59  7808.0    1434.908687   1128.614732            364.239726
    60  7936.0    1437.333241   1142.473968            364.541086
    61  8064.0    1439.534483   1148.462722            365.239337
    62  8192.0    1434.833063   1154.898261            363.845241
    63  8320.0    1386.752986   1117.786420            361.605137
    64  8448.0    1384.280820   1123.944237            362.393944
    65  8576.0    1389.226155   1129.294474            363.307872
    66  8704.0    1380.347858   1135.291856            364.552067
    67  8832.0    1393.668994   1133.518716            365.089439
    68  8960.0    1384.517301   1140.328966            365.754275
    69  9088.0    1397.197827   1135.561895            366.719985
    70  9216.0    1403.241970   1145.572096            367.342710
    71  9344.0    1392.995017   1418.761336            367.940348
    72  9472.0    1399.077128   1432.133729            368.842866
    73  9600.0    1400.718604   1434.244686            369.301789
    74  9728.0    1404.111921   1440.469768            369.741263
    75  9856.0    1400.960859   1439.488831            369.959053
    76  9984.0    1393.449866   1446.412693            370.707408
    77 10112.0    1407.583531   1452.226544            370.527939
    78 10240.0    1409.031442   1467.483041            370.939336
    79 10368.0    1416.539980   1461.970805            369.775391
    80 10496.0    1407.235159   1463.307013            370.572934
    81 10624.0    1407.112494   1466.803719            370.100704
    82 10752.0    1398.562784   1470.939602            370.719112
    83 10880.0    1396.692204   1481.506478            372.261889
    84 11008.0    1418.051254   1475.019222            372.337102
    85 11136.0    1420.523821   1485.652107            372.629117
    86 11264.0    1415.585521   1487.750966            373.387818
    87 11392.0    1422.628758   1489.299318            374.063266
    88 11520.0    1414.930773   1496.342830            373.450926
    89 11648.0    1422.104103   1498.536584            374.400981
    90 11776.0    1433.487736   1502.962543            374.513588
    91 11904.0    1429.229709   1509.899910            375.307322
    92 12032.0    1416.883328   1510.797982            375.814286
    93 12160.0    1417.431843   1513.939569            375.886067
    94 12288.0    1425.390422   1420.751343            376.296585
    95 12416.0    1434.862105   1397.243009            374.694501
    96 12544.0    1441.462524   1397.725060            375.568866
    97 12672.0    1433.653880   1392.167809            375.277539


.. GENERATED FROM PYTHON SOURCE LINES 232-236

In the above plot, we can see that:

 - Triton is about 4x faster than :code:`naive_softmax`, in line with our earlier
   estimate, confirming that eager PyTorch performs no fusion here.
 - Triton is noticeably faster than :code:`torch.softmax` -- in addition to being
   **easier to read, understand and maintain**.
   Note however that the PyTorch :code:`softmax` operation is more general and will
   work on tensors of any shape.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 35.024 seconds)


.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 02-fused-softmax.zip <02-fused-softmax.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_