.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/02-fused-softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:

Fused Softmax
=============

In this tutorial, you will write a fused softmax operation that is significantly faster than
PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM.

In doing so, you will learn about:

* The benefits of kernel fusion for bandwidth-bound operations.

* Reduction operators in Triton.

.. GENERATED FROM PYTHON SOURCE LINES 18-23

Motivations
-----------

Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far
in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

.. GENERATED FROM PYTHON SOURCE LINES 23-62

.. code-block:: Python

    import torch

    import triton
    import triton.language as tl
    from triton.runtime import driver

    DEVICE = triton.runtime.driver.active.get_active_torch_device()


    def is_hip():
        return triton.runtime.driver.active.get_current_target().backend == "hip"


    def is_cdna():
        return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                       'gfx90a', 'gfx908')


    def naive_softmax(x):
        """Compute row-wise softmax of X using native PyTorch.

        We subtract the maximum element in order to avoid overflows. Softmax is invariant to this shift.
        """
        # read MN elements ; write M elements
        x_max = x.max(dim=1)[0]
        # read MN + M elements ; write MN elements
        z = x - x_max[:, None]
        # read MN elements ; write MN elements
        numerator = torch.exp(z)
        # read MN elements ; write M elements
        denominator = numerator.sum(dim=1)
        # read MN + M elements ; write MN elements
        ret = numerator / denominator[:, None]
        # in total: read 5MN + 2M elements ; write 3MN + 2M elements
        return ret

.. GENERATED FROM PYTHON SOURCE LINES 63-71

When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for
:math:`x \in \mathbb{R}^{M \times N}` requires reading :math:`5MN + 2M` elements from DRAM and
writing back :math:`3MN + 2M` elements. This is obviously wasteful; we'd prefer to have a custom
"fused" kernel that only reads X once and does all the necessary computations on-chip. Doing so
would require reading and writing back only :math:`MN` elements each way, so we could expect a
theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`). The :code:`torch.jit.script` flag
aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still
far from ideal.
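For completeness, here is how one would hand :code:`naive_softmax` to the PyTorch JIT. This snippet
is only an illustrative sketch -- it is not part of the generated tutorial script, and it is not
included in the benchmark below.

.. code-block:: Python

    # Illustrative sketch (not part of the tutorial script): ask the PyTorch JIT to
    # compile `naive_softmax` and fuse what it can. In practice the resulting code
    # still falls short of a hand-written fused kernel.
    scripted_softmax = torch.jit.script(naive_softmax)

    x = torch.randn(1024, 1024, device=DEVICE)
    torch.testing.assert_close(scripted_softmax(x), naive_softmax(x))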
.. GENERATED FROM PYTHON SOURCE LINES 73-82

Compute Kernel
--------------

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X,
strided by the number of programs, normalizes each of them, and writes the results back to the
output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number
of elements, so we need to internally "pad" each row and guard the memory operations properly if
we want to handle any possible input shapes:

.. GENERATED FROM PYTHON SOURCE LINES 82-112

.. code-block:: Python

    @triton.jit
    def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                       num_stages: tl.constexpr):
        # starting row of the program
        row_start = tl.program_id(0)
        row_step = tl.num_programs(0)
        for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
            # The stride represents how much we need to increase the pointer to advance 1 row
            row_start_ptr = input_ptr + row_idx * input_row_stride
            # The block size is the next power of two greater than or equal to n_cols, so we can fit each
            # row in a single block
            col_offsets = tl.arange(0, BLOCK_SIZE)
            input_ptrs = row_start_ptr + col_offsets
            # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
            mask = col_offsets < n_cols
            row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
            # Subtract maximum for numerical stability
            row_minus_max = row - tl.max(row, axis=0)
            # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)
            softmax_output = numerator / denominator
            # Write back output to DRAM
            output_row_start_ptr = output_ptr + row_idx * output_row_stride
            output_ptrs = output_row_start_ptr + col_offsets
            tl.store(output_ptrs, softmax_output, mask=mask)
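Before wrapping the kernel in an occupancy-aware helper, it may help to see the simplest possible
launch. The sketch below is not part of the generated tutorial script; it launches one program per
row, so each program's loop over rows executes exactly once:

.. code-block:: Python

    # Illustrative sketch (not part of the tutorial script): the simplest launch,
    # one program per row, i.e. one loop iteration per program.
    def softmax_simple(x):
        n_rows, n_cols = x.shape
        # BLOCK_SIZE must be a power of two large enough to hold one full row.
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        y = torch.empty_like(x)
        softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                   num_stages=2)
        return y

The helper defined next instead launches a fixed number of *persistent* programs, sized from the
kernel's register and shared-memory usage, and lets each program iterate over several rows.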
.. GENERATED FROM PYTHON SOURCE LINES 113-114

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

.. GENERATED FROM PYTHON SOURCE LINES 114-178

.. code-block:: Python

    properties = driver.active.utils.get_device_properties(DEVICE.index)
    NUM_SM = properties["multiprocessor_count"]
    NUM_REGS = properties["max_num_regs"]
    SIZE_SMEM = properties["max_shared_mem"]
    WARP_SIZE = properties["warpSize"]
    target = triton.runtime.driver.active.get_current_target()
    kernels = {}


    def softmax(x):
        n_rows, n_cols = x.shape

        # The block size of each loop iteration is the smallest power of two greater than or equal to
        # the number of columns in `x`
        BLOCK_SIZE = triton.next_power_of_2(n_cols)

        # Another trick we can use is to ask the compiler to use more threads per row by
        # increasing the number of warps (`num_warps`) over which each row is distributed.
        # You will see in the next tutorial how to auto-tune this value in a more natural
        # way so you don't have to come up with manual heuristics yourself.
        num_warps = 8

        # Number of software pipelining stages.
        num_stages = 4 if SIZE_SMEM > 200000 else 2

        # Allocate output
        y = torch.empty_like(x)

        # pre-compile kernel to get register usage and compute thread occupancy.
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.

            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
            # not required to be equal numbers of both types.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2

            # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
            # When we divide this number by WARP_SIZE we get the maximum number of waves that can
            # execute on a CU (multi-processor) in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy

        num_programs = min(num_programs, n_rows)

        # Create a number of persistent programs.
        kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
        return y

.. GENERATED FROM PYTHON SOURCE LINES 179-181

Unit Test
---------

.. GENERATED FROM PYTHON SOURCE LINES 183-185

We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.

.. GENERATED FROM PYTHON SOURCE LINES 185-192

.. code-block:: Python

    torch.manual_seed(0)
    x = torch.randn(1823, 781, device=DEVICE)
    y_triton = softmax(x)
    y_torch = torch.softmax(x, axis=1)
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

.. GENERATED FROM PYTHON SOURCE LINES 193-194

As expected, the results are identical.

.. GENERATED FROM PYTHON SOURCE LINES 196-201

Benchmark
---------

Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.

.. GENERATED FROM PYTHON SOURCE LINES 201-231

.. code-block:: Python

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],  # argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
            line_arg='provider',  # argument name whose value corresponds to a different line in the plot
            line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
            line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
            styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
            ylabel="GB/s",  # label name for the y-axis
            plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
        ))
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
        stream = getattr(torch, DEVICE.type).Stream()
        getattr(torch, DEVICE.type).set_stream(stream)
        if provider == 'torch':
            ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
        if provider == 'triton':
            ms = triton.testing.do_bench(lambda: softmax(x))
        if provider == 'naive_softmax':
            ms = triton.testing.do_bench(lambda: naive_softmax(x))
        gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms)


    benchmark.run(show_plots=True, print_data=True)
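As a quick sanity check on the :code:`gbps` formula above: a fused softmax reads every element of
:math:`x` once and writes every element of :math:`y` once, hence the factor of 2. The numbers below
are made up for illustration and are not part of the tutorial script.

.. code-block:: Python

    # Illustrative sketch (hypothetical timing, not measured): bytes moved by a fused
    # softmax on a 4096 x 1024 fp32 matrix, converted to an effective bandwidth.
    M, N = 4096, 1024
    bytes_moved = 2 * M * N * 4               # one fp32 read + one fp32 write per element
    ms = 0.05                                 # suppose do_bench reported 0.05 ms
    print(bytes_moved * 1e-9 / (ms * 1e-3))   # ~671 GB/s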
.. image-sg:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :alt: 02 fused softmax
   :srcset: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    softmax-performance:
              N  Triton (GB/s)  Torch (GB/s)  Naive Softmax (GB/s)
     0    256.0   474.202004   693.015690  207.634947
     1    384.0   669.058687   831.590195  262.853008
     2    512.0   804.970735   915.383841  300.607205
     3    640.0   921.944720   928.511893  329.328702
     4    768.0   986.206881   985.175696  348.703260
     5    896.0  1052.594410  1037.775777  355.388074
     6   1024.0  1085.627944  1080.391408  356.026482
     7   1152.0  1106.896050  1077.615152  350.571411
     8   1280.0  1130.668827  1112.478405  349.429682
     9   1408.0  1171.411567  1141.511466  342.048678
    10   1536.0  1199.439915  1160.034499  333.987379
    11   1664.0  1210.297871  1184.309365  329.473583
    12   1792.0  1230.411543  1200.481739  325.927118
    13   1920.0  1260.197226  1218.058495  324.405743
    14   2048.0  1274.380691  1252.414336  324.287455
    15   2176.0  1240.894576   958.345008  325.746776
    16   2304.0  1257.237503   998.456304  326.638801
    17   2432.0  1275.655325  1035.347344  326.689151
    18   2560.0  1283.204353  1068.862861  328.588619
    19   2688.0  1299.808020  1098.397833  329.078285
    20   2816.0  1308.060779  1121.319050  329.807261
    21   2944.0  1320.965402  1140.697107  331.545978
    22   3072.0  1318.204522  1170.885973  334.068603
    23   3200.0  1338.798603  1171.396036  334.960246
    24   3328.0  1341.965806  1201.979794  336.006963
    25   3456.0  1354.159141  1220.608874  336.202385
    26   3584.0  1356.258914  1244.494008  338.156005
    27   3712.0  1362.456922  1265.083345  340.467085
    28   3840.0  1369.647783  1280.794950  340.292034
    29   3968.0  1376.735565  1294.928124  341.090176
    30   4096.0  1387.314334  1315.001435  338.484385
    31   4224.0  1329.135686  1274.530249  343.460460
    32   4352.0  1345.297315  1300.555120  345.563337
    33   4480.0  1349.045102  1318.002397  345.830772
    34   4608.0  1362.873656  1333.663218  346.730572
    35   4736.0  1358.061374  1341.719627  347.441785
    36   4864.0  1364.064769  1354.769797  348.586105
    37   4992.0  1370.584602  1371.468763  350.705204
    38   5120.0  1380.397185  1383.514122  351.228146
    39   5248.0  1377.907563  1351.973292  351.980276
    40   5376.0  1377.929199  1366.247064  351.675930
    41   5504.0  1385.560731  1380.929227  353.943679
    42   5632.0  1394.456776  1396.710872  352.772246
    43   5760.0  1393.402804  1404.570410  355.372754
    44   5888.0  1394.411300  1413.818713  354.985591
    45   6016.0  1402.452183  1412.746967  356.503707
    46   6144.0  1408.637104  1432.645448  356.616913
    47   6272.0  1408.213825  1400.839398  357.816605
    48   6400.0  1416.523770  1414.483030  358.465895
    49   6528.0  1419.554254  1412.671037  359.637999
    50   6656.0  1419.183652  1432.277703  359.454697
    51   6784.0  1416.643987  1429.138790  360.156798
    52   6912.0  1422.041776  1444.805017  360.601590
    53   7040.0  1419.271030  1455.509398  361.075758
    54   7168.0  1422.176061  1450.303580  361.878558
    55   7296.0  1426.762687  1082.160100  362.004036
    56   7424.0  1431.547410  1097.962219  362.526844
    57   7552.0  1425.133524  1106.707837  363.090562
    58   7680.0  1428.345316  1119.019612  363.704017
    59   7808.0  1430.349212  1130.202089  364.817012
    60   7936.0  1430.758228  1139.473082  364.450101
    61   8064.0  1436.306829  1147.306546  365.139599
    62   8192.0  1432.648036  1149.114691  364.251501
    63   8320.0  1384.813629  1114.656971  361.855338
    64   8448.0  1386.679595  1122.359517  362.637008
    65   8576.0  1389.880988  1126.096439  363.762687
    66   8704.0  1385.089706  1132.663440  364.458445
    67   8832.0  1392.078813  1131.701907  365.022713
    68   8960.0  1383.475225  1137.550174  365.319322
    69   9088.0  1396.748351  1137.087012  366.577753
    70   9216.0  1400.802363  1141.644786  367.507150
    71   9344.0  1388.157273  1417.871043  367.787952
    72   9472.0  1404.481684  1432.326586  368.735175
    73   9600.0  1401.651942  1430.628702  369.261837
    74   9728.0  1402.310650  1440.777857  370.068998
    75   9856.0  1398.177150  1442.239221  369.467579
    76   9984.0  1396.143565  1451.272155  370.542670
    77  10112.0  1401.824266  1453.255822  371.446418
    78  10240.0  1406.646422  1461.334077  371.687840
    79  10368.0  1415.045600  1458.125304  370.032596
    80  10496.0  1410.147334  1467.762549  370.639606
    81  10624.0  1403.573083  1463.744512  370.912718
    82  10752.0  1399.014872  1471.071614  371.031025
    83  10880.0  1394.045961  1478.994165  372.135565
    84  11008.0  1417.933247  1478.914649  373.112573
    85  11136.0  1418.076112  1485.180439  373.099061
    86  11264.0  1415.722746  1485.462179  372.810745
    87  11392.0  1420.911065  1490.045737  373.759096
    88  11520.0  1418.961396  1498.997434  374.063922
    89  11648.0  1424.253938  1497.820064  374.228506
    90  11776.0  1431.592938  1501.903403  374.385129
    91  11904.0  1431.867907  1505.895399  375.285313
    92  12032.0  1415.342745  1511.088155  375.531644
    93  12160.0  1411.386836  1516.998972  375.908209
    94  12288.0  1427.114975  1422.788053  376.007834
    95  12416.0  1438.054397  1397.360475  374.721002
    96  12544.0  1441.766818  1395.472580  375.647945
    97  12672.0  1436.024271  1392.099167  375.158878

.. GENERATED FROM PYTHON SOURCE LINES 232-236

In the above plot, we can see that:

- Triton is about 4x faster than :code:`naive_softmax`. This is consistent with our earlier
  analysis: without fusion, roughly four times as much data has to move through DRAM.
- Triton is noticeably faster than :code:`torch.softmax` -- in addition to being **easier to
  read, understand and maintain**. Note however that the PyTorch :code:`softmax` operation is
  more general and will work on tensors of any shape.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 35.025 seconds)


.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 02-fused-softmax.zip <02-fused-softmax.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_