.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/02-fused-softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:

Fused Softmax
=============

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op
for a particular class of matrices: those whose rows can fit in the GPU's SRAM.

In doing so, you will learn about:

* The benefits of kernel fusion for bandwidth-bound operations.
* Reduction operators in Triton.

.. GENERATED FROM PYTHON SOURCE LINES 18-23

Motivations
-----------

Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:

.. GENERATED FROM PYTHON SOURCE LINES 23-62

.. code-block:: Python

    import torch

    import triton
    import triton.language as tl
    from triton.runtime import driver

    DEVICE = triton.runtime.driver.active.get_active_torch_device()


    def is_hip():
        return triton.runtime.driver.active.get_current_target().backend == "hip"


    def is_cdna():
        return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                       'gfx90a', 'gfx908')


    def naive_softmax(x):
        """Compute row-wise softmax of X using native PyTorch.

        We subtract the maximum element in order to avoid overflows. Softmax is invariant to
        this shift.
        """
        # read MN elements ; write M elements
        x_max = x.max(dim=1)[0]
        # read MN + M elements ; write MN elements
        z = x - x_max[:, None]
        # read MN elements ; write MN elements
        numerator = torch.exp(z)
        # read MN elements ; write M elements
        denominator = numerator.sum(dim=1)
        # read MN + M elements ; write MN elements
        ret = numerator / denominator[:, None]
        # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
        return ret

.. GENERATED FROM PYTHON SOURCE LINES 63-71

When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads X once and does all the
necessary computations on-chip.
Doing so would require reading and writing back only :math:`MN` elements, so we could expect a theoretical
speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
The `torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically but, as we will see later,
it is still far from ideal.
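To make the arithmetic concrete, here is a quick back-of-the-envelope check. The shape below is purely
illustrative and the snippet is not part of the tutorial code; the ratio is :math:`(8MN + 4M) / 2MN` for any
:math:`M` and :math:`N`.

.. code-block:: Python

    # Hypothetical shape, counted in elements (the dtype does not affect the ratio).
    M, N = 4096, 1024
    naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
    fused_traffic = 2 * M * N                                  # fused kernel: read x once, write y once
    print(naive_traffic / fused_traffic)                       # ~4.0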
.. GENERATED FROM PYTHON SOURCE LINES 73-82

Compute Kernel
--------------

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X strided by the number
of programs, normalizes them and writes back the results to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we
need to internally "pad" each row and guard the memory operations properly if we want to handle any possible input
shapes:

.. GENERATED FROM PYTHON SOURCE LINES 82-112

.. code-block:: Python

    @triton.jit
    def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols,
                       BLOCK_SIZE: tl.constexpr, num_stages: tl.constexpr):
        # starting row of the program
        row_start = tl.program_id(0)
        row_step = tl.num_programs(0)
        for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
            # The stride represents how much we need to increase the pointer to advance 1 row
            row_start_ptr = input_ptr + row_idx * input_row_stride
            # The block size is the next power of two greater than or equal to n_cols, so we can fit each
            # row in a single block
            col_offsets = tl.arange(0, BLOCK_SIZE)
            input_ptrs = row_start_ptr + col_offsets
            # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
            mask = col_offsets < n_cols
            row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
            # Subtract maximum for numerical stability
            row_minus_max = row - tl.max(row, axis=0)
            # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)
            softmax_output = numerator / denominator
            # Write back output to DRAM
            output_row_start_ptr = output_ptr + row_idx * output_row_stride
            output_ptrs = output_row_start_ptr + col_offsets
            tl.store(output_ptrs, softmax_output, mask=mask)
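The kernel alone is enough to compute a correct softmax. As a point of reference, a minimal, untuned launch would
simply use one program per row; the wrapper below is only a sketch for illustration (the name
:code:`softmax_simple` is introduced here and is not part of the tutorial code):

.. code-block:: Python

    def softmax_simple(x):
        # Minimal sketch: launch one Triton program per row, with no tuning of
        # num_warps or num_stages. Assumes x is a 2D floating-point GPU tensor.
        n_rows, n_cols = x.shape
        y = torch.empty_like(x)
        # Each row must fit in a single block, and block sizes must be powers of two.
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        softmax_kernel[(n_rows, )](y, x, x.stride(0), y.stride(0), n_rows, n_cols,
                                   BLOCK_SIZE=BLOCK_SIZE, num_stages=2)
        return y

The launcher developed next improves on this by choosing :code:`num_warps`, :code:`num_stages` and a persistent
grid size based on the GPU's occupancy.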
.. GENERATED FROM PYTHON SOURCE LINES 113-114

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

.. GENERATED FROM PYTHON SOURCE LINES 114-178

.. code-block:: Python

    properties = driver.active.utils.get_device_properties(DEVICE.index)
    NUM_SM = properties["multiprocessor_count"]
    NUM_REGS = properties["max_num_regs"]
    SIZE_SMEM = properties["max_shared_mem"]
    WARP_SIZE = properties["warpSize"]
    target = triton.runtime.driver.active.get_current_target()
    kernels = {}


    def softmax(x):
        n_rows, n_cols = x.shape

        # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
        BLOCK_SIZE = triton.next_power_of_2(n_cols)

        # Another trick we can use is to ask the compiler to use more threads per row by
        # increasing the number of warps (`num_warps`) over which each row is distributed.
        # You will see in the next tutorial how to auto-tune this value in a more natural
        # way so you don't have to come up with manual heuristics yourself.
        num_warps = 8

        # Number of software pipelining stages.
        num_stages = 4 if SIZE_SMEM > 200000 else 2

        # Allocate output
        y = torch.empty_like(x)

        # Pre-compile the kernel to get register usage and compute thread occupancy.
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.
            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
            # not required to be equal numbers of both types.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2

            # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
            # When we divide this number by WARP_SIZE we get the maximum number of waves that can
            # execute on a CU (multi-processor) in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy

        num_programs = min(num_programs, n_rows)

        # Create a number of persistent programs.
        kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
        return y

.. GENERATED FROM PYTHON SOURCE LINES 179-181

Unit Test
---------

.. GENERATED FROM PYTHON SOURCE LINES 183-185

We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.

.. GENERATED FROM PYTHON SOURCE LINES 185-192

.. code-block:: Python

    torch.manual_seed(0)
    x = torch.randn(1823, 781, device=DEVICE)
    y_triton = softmax(x)
    y_torch = torch.softmax(x, axis=1)
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

.. GENERATED FROM PYTHON SOURCE LINES 193-194

As expected, the results are identical.

.. GENERATED FROM PYTHON SOURCE LINES 196-201

Benchmark
---------

Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.

.. GENERATED FROM PYTHON SOURCE LINES 201-231

.. code-block:: Python

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],  # argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
            line_arg='provider',  # argument name whose value corresponds to a different line in the plot
            line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
            line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
            styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
            ylabel="GB/s",  # label name for the y-axis
            plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
        ))
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
        stream = getattr(torch, DEVICE.type).Stream()
        getattr(torch, DEVICE.type).set_stream(stream)
        if provider == 'torch':
            ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
        if provider == 'triton':
            ms = triton.testing.do_bench(lambda: softmax(x))
        if provider == 'naive_softmax':
            ms = triton.testing.do_bench(lambda: naive_softmax(x))
        gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms)


    benchmark.run(show_plots=True, print_data=True)
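As a brief aside (this note is not part of the generated script output), the :code:`gbps` lambda above converts a
measured runtime into an effective bandwidth by assuming that each implementation reads the input and writes the
output exactly once, i.e. it moves :math:`2MN` float32 elements:

.. math::

    \text{GB/s} = \frac{2 \, M N \cdot 4\ \text{bytes}}{10^{9} \cdot t_{\text{seconds}}}

For example, a :math:`4096 \times 4096` float32 matrix amounts to about 134 MB of traffic under this assumption,
so a 0.1 ms run corresponds to roughly 1.3 TB/s. For :code:`naive_softmax`, which in reality moves more than
:math:`2MN` elements, this number is best read as useful throughput rather than raw DRAM traffic.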
.. image-sg:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :alt: 02 fused softmax
   :srcset: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    softmax-performance:
              N       Triton        Torch  Naive Softmax
    0     256.0   469.514462   701.720876     203.733288
    1     384.0   654.273227   814.305270     261.694028
    2     512.0   814.417367   928.773532     298.986900
    3     640.0   917.491226   911.193043     328.589100
    4     768.0   986.414125   982.018834     348.526030
    5     896.0  1042.315225  1034.205096     356.558861
    6    1024.0  1080.392793  1076.108885     353.733851
    7    1152.0  1089.306537  1068.578444     349.904604
    8    1280.0  1139.712447  1110.631314     349.639201
    9    1408.0  1170.285736  1132.626223     341.646845
    10   1536.0  1187.455192  1165.838004     332.033070
    11   1664.0  1208.278009  1179.603421     329.516358
    12   1792.0  1229.608681  1199.153274     325.422930
    13   1920.0  1260.731186  1225.723310     324.379890
    14   2048.0  1273.346472  1243.064685     324.114739
    15   2176.0  1239.302547   962.712717     326.065782
    16   2304.0  1259.799550  1004.674545     326.401742
    17   2432.0  1269.606628  1034.415417     326.612020
    18   2560.0  1288.642442  1069.035893     328.668815
    19   2688.0  1290.071848  1097.264595     329.496778
    20   2816.0  1310.594856  1121.737068     329.878282
    21   2944.0  1316.235613  1141.121541     331.220775
    22   3072.0  1324.469228  1168.303611     333.086162
    23   3200.0  1338.946308  1175.114069     334.449520
    24   3328.0  1347.462515  1199.742674     336.440978
    25   3456.0  1355.547201  1224.777722     336.564395
    26   3584.0  1357.416710  1246.009689     338.488583
    27   3712.0  1364.807717  1263.727098     340.733560
    28   3840.0  1375.680313  1283.787260     340.581597
    29   3968.0  1371.808258  1303.471637     341.438661
    30   4096.0  1390.020080  1313.449279     338.896535
    31   4224.0  1327.725801  1280.157445     342.906846
    32   4352.0  1339.358434  1302.072847     344.947062
    33   4480.0  1346.190345  1316.456486     345.141162
    34   4608.0  1355.613413  1334.011122     346.772848
    35   4736.0  1355.536089  1344.109839     347.990477
    36   4864.0  1366.297629  1357.748299     348.901296
    37   4992.0  1373.589693  1373.475469     350.382606
    38   5120.0  1380.149639  1388.702766     351.196480
    39   5248.0  1374.961307  1359.224786     351.915750
    40   5376.0  1384.165510  1370.680326     352.005988
    41   5504.0  1378.635705  1369.409242     354.070907
    42   5632.0  1395.092244  1399.798076     352.869594
    43   5760.0  1393.334113  1403.199670     354.921211
    44   5888.0  1387.075676  1419.912292     354.665410
    45   6016.0  1402.339784  1414.110128     357.038277
    46   6144.0  1406.697822  1431.131480     356.598435
    47   6272.0  1406.089541  1392.576085     357.811982
    48   6400.0  1417.440270  1418.675514     358.923211
    49   6528.0  1413.994817  1424.850232     359.365592
    50   6656.0  1414.435870  1424.811153     359.575137
    51   6784.0  1414.743063  1444.718721     360.371846
    52   6912.0  1423.728339  1441.693653     361.084494
    53   7040.0  1419.575185  1458.508917     360.667834
    54   7168.0  1420.152482  1461.927728     361.190195
    55   7296.0  1429.007992  1087.540864     362.465524
    56   7424.0  1427.455451  1101.934524     363.076354
    57   7552.0  1430.168058  1111.906556     363.358978
    58   7680.0  1432.008878  1126.135414     363.773756
    59   7808.0  1431.983015  1134.267066     364.794248
    60   7936.0  1430.204821  1144.594469     365.046874
    61   8064.0  1434.933376  1151.257408     365.148665
    62   8192.0  1432.900342  1156.180190     364.599759
    63   8320.0  1382.324490  1118.711652     361.958201
    64   8448.0  1386.130160  1126.996306     362.715247
    65   8576.0  1388.768226  1127.953488     363.396962
    66   8704.0  1378.257060  1136.287941     363.831087
    67   8832.0  1393.074619  1134.647144     365.044952
    68   8960.0  1385.884468  1142.441331     365.958793
    69   9088.0  1398.805494  1140.754383     367.044866
    70   9216.0  1407.487792  1143.821205     367.080799
    71   9344.0  1389.840604  1419.582122     367.048764
    72   9472.0  1400.307904  1433.064105     368.730778
    73   9600.0  1397.889614  1432.155143     369.262150
    74   9728.0  1398.822236  1434.238747     369.021508
    75   9856.0  1400.746135  1437.656097     369.280271
    76   9984.0  1396.662618  1451.131488     370.476223
    77  10112.0  1402.074341  1453.712692     371.441977
    78  10240.0  1410.893427  1463.382786     370.536043
    79  10368.0  1416.413183  1457.999997     369.615927
    80  10496.0  1408.770788  1467.478260     370.893188
    81  10624.0  1406.982464  1465.203166     370.943849
    82  10752.0  1394.997234  1471.004989     371.639663
    83  10880.0  1398.910984  1474.626878     371.661389
    84  11008.0  1419.819526  1473.553882     371.732265
    85  11136.0  1418.711148  1481.699888     371.970535
    86  11264.0  1412.924493  1486.571172     372.633543
    87  11392.0  1416.822112  1490.902530     373.679827
    88  11520.0  1414.108433  1493.294048     373.609882
    89  11648.0  1418.010574  1495.908583     373.826687
    90  11776.0  1436.228342  1501.767729     375.014970
    91  11904.0  1428.179762  1507.490890     375.593671
    92  12032.0  1415.695771  1507.968597     375.885013
    93  12160.0  1413.413372  1514.455455     375.333370
    94  12288.0  1426.707205  1416.370757     375.449361
    95  12416.0  1435.116494  1395.422561     374.668000
    96  12544.0  1445.430160  1391.137193     374.740526
    97  12672.0  1436.434092  1391.890159     374.785805

.. GENERATED FROM PYTHON SOURCE LINES 232-236

In the above plot, we can see that:

- Triton is about 4x faster than :code:`naive_softmax`. This confirms our suspicion that the naive PyTorch
  implementation does not perform any fusion here.
- Triton is noticeably faster than :code:`torch.softmax` for most of the tested shapes -- in addition to being
  **easier to read, understand and maintain**.
  Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 35.929 seconds)


.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 02-fused-softmax.zip <02-fused-softmax.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_