.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/02-fused-softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:


Fused Softmax
=============

In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
the GPU's SRAM.

In doing so, you will learn about:

* The benefits of kernel fusion for bandwidth-bound operations.

* Reduction operators in Triton.

.. GENERATED FROM PYTHON SOURCE LINES 18-23

Motivations
-----------

Custom GPU kernels for elementwise additions are educationally valuable but won't get you
very far in practice. Let us consider instead the case of a simple (numerically stabilized)
softmax operation:

.. GENERATED FROM PYTHON SOURCE LINES 23-62

.. code-block:: Python


    import torch

    import triton
    import triton.language as tl
    from triton.runtime import driver

    DEVICE = triton.runtime.driver.active.get_active_torch_device()


    def is_hip():
        return triton.runtime.driver.active.get_current_target().backend == "hip"


    def is_cdna():
        return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx940', 'gfx941', 'gfx942',
                                                                                       'gfx90a', 'gfx908')


    def naive_softmax(x):
        """Compute row-wise softmax of X using native PyTorch.

        We subtract the maximum element in order to avoid overflows. Softmax is invariant to
        this shift.
        """
        # read  MN elements ; write M  elements
        x_max = x.max(dim=1)[0]
        # read MN + M elements ; write MN elements
        z = x - x_max[:, None]
        # read  MN elements ; write MN elements
        numerator = torch.exp(z)
        # read  MN elements ; write M  elements
        denominator = numerator.sum(dim=1)
        # read MN + M elements ; write MN elements
        ret = numerator / denominator[:, None]
        # in total: read 5MN + 2M elements ; write 3MN + 2M elements
        return ret


.. GENERATED FROM PYTHON SOURCE LINES 63-71

When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading :math:`MN` elements and writing back :math:`MN` elements, so we could
expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
The :code:`torch.jit.script` flag aims to perform this kind of "kernel fusion" automatically, but it is still far from ideal.
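As a quick sanity check of the traffic arithmetic above (this snippet is an editorial
illustration, not part of the generated tutorial code), we can plug concrete values of
:math:`M` and :math:`N` into the estimates:

.. code-block:: Python


    # naive_softmax moves 8MN + 4M elements through DRAM (reads plus writes),
    # while a fused kernel only needs to read MN elements and write MN elements.
    M, N = 4096, 1024
    naive_traffic = (5 * M * N + 2 * M) + (3 * M * N + 2 * M)  # reads + writes
    fused_traffic = 2 * M * N                                  # read MN, write MN
    print(naive_traffic / fused_traffic)  # ~4.002, i.e. the expected ~4x speed-up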
.. GENERATED FROM PYTHON SOURCE LINES 73-82

Compute Kernel
--------------

Our softmax kernel works as follows: each program loads a set of rows of the input matrix X,
strided by the number of programs, normalizes them and writes back the results to the output Y.

Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally "pad" each row and guard the
memory operations properly if we want to handle any possible input shapes:

.. GENERATED FROM PYTHON SOURCE LINES 82-112

.. code-block:: Python


    @triton.jit
    def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_rows, n_cols, BLOCK_SIZE: tl.constexpr,
                       num_stages: tl.constexpr):
        # starting row of the program
        row_start = tl.program_id(0)
        row_step = tl.num_programs(0)
        for row_idx in tl.range(row_start, n_rows, row_step, num_stages=num_stages):
            # The stride represents how much we need to increase the pointer to advance 1 row
            row_start_ptr = input_ptr + row_idx * input_row_stride
            # The block size is the next power of two greater than n_cols, so we can fit each
            # row in a single block
            col_offsets = tl.arange(0, BLOCK_SIZE)
            input_ptrs = row_start_ptr + col_offsets
            # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
            mask = col_offsets < n_cols
            row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
            # Subtract maximum for numerical stability
            row_minus_max = row - tl.max(row, axis=0)
            # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)
            softmax_output = numerator / denominator
            # Write back output to DRAM
            output_row_start_ptr = output_ptr + row_idx * output_row_stride
            output_ptrs = output_row_start_ptr + col_offsets
            tl.store(output_ptrs, softmax_output, mask=mask)
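Why does padding with :code:`-float('inf')` leave the result unchanged? Because
:math:`\exp(-\infty) = 0`, the padded columns contribute nothing to either the row maximum
or the denominator. The following PyTorch snippet (an editorial illustration, not part of
the generated tutorial code) demonstrates this for a row of 781 elements padded to a block
of 1024:

.. code-block:: Python


    import torch

    # A row with an "irregular" length, padded up to the next power of two with -inf.
    row = torch.randn(781)
    padded = torch.cat([row, torch.full((1024 - 781, ), float('-inf'))])
    # The first 781 entries of the padded softmax match the softmax of the original row.
    assert torch.allclose(torch.softmax(padded, dim=0)[:781], torch.softmax(row, dim=0))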
.. GENERATED FROM PYTHON SOURCE LINES 113-114

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

.. GENERATED FROM PYTHON SOURCE LINES 114-178

.. code-block:: Python


    properties = driver.active.utils.get_device_properties(DEVICE.index)
    NUM_SM = properties["multiprocessor_count"]
    NUM_REGS = properties["max_num_regs"]
    SIZE_SMEM = properties["max_shared_mem"]
    WARP_SIZE = properties["warpSize"]
    target = triton.runtime.driver.active.get_current_target()
    kernels = {}


    def softmax(x):
        n_rows, n_cols = x.shape

        # The block size of each loop iteration is the smallest power of two greater than the number of columns in `x`
        BLOCK_SIZE = triton.next_power_of_2(n_cols)

        # Another trick we can use is to ask the compiler to use more threads per row by
        # increasing the number of warps (`num_warps`) over which each row is distributed.
        # You will see in the next tutorial how to auto-tune this value in a more natural
        # way so you don't have to come up with manual heuristics yourself.
        num_warps = 8

        # Number of software pipelining stages.
        num_stages = 4 if SIZE_SMEM > 200000 else 2

        # Allocate output
        y = torch.empty_like(x)

        # pre-compile kernel to get register usage and compute thread occupancy.
        kernel = softmax_kernel.warmup(y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE=BLOCK_SIZE,
                                       num_stages=num_stages, num_warps=num_warps, grid=(1, ))
        kernel._init_handles()
        n_regs = kernel.n_regs
        size_smem = kernel.metadata.shared
        if is_hip():
            # NUM_REGS represents the number of regular purpose registers. On CDNA architectures this is half of all registers available.
            # However, this is not always the case. In most cases all registers can be used as regular purpose registers.

            # ISA SECTION (3.6.4 for CDNA3)
            # VGPRs are allocated out of two pools: regular VGPRs and accumulation VGPRs. Accumulation VGPRs are used
            # with matrix VALU instructions, and can also be loaded directly from memory. A wave may have up to 512 total
            # VGPRs, 256 of each type. When a wave has fewer than 512 total VGPRs, the number of each type is flexible - it is
            # not required to be equal numbers of both types.
            NUM_GPRS = NUM_REGS
            if is_cdna():
                NUM_GPRS = NUM_REGS * 2

            # MAX_NUM_THREADS represents the maximum number of resident threads per multi-processor.
            # When we divide this number by WARP_SIZE we get the maximum number of waves that can
            # execute on a CU (multi-processor) in parallel.
            MAX_NUM_THREADS = properties["max_threads_per_sm"]
            max_num_waves = MAX_NUM_THREADS // WARP_SIZE
            occupancy = min(NUM_GPRS // WARP_SIZE // n_regs, max_num_waves) // num_warps
        else:
            occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)
        occupancy = min(occupancy, SIZE_SMEM // size_smem)
        num_programs = NUM_SM * occupancy

        num_programs = min(num_programs, n_rows)

        # Create a number of persistent programs.
        kernel[(num_programs, 1, 1)](y, x, x.stride(0), y.stride(0), n_rows, n_cols, BLOCK_SIZE, num_stages)
        return y
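To make the occupancy computation above more concrete, here is the same arithmetic worked
through with made-up but plausible numbers (an editorial illustration; the real values come
from the driver query and from the compiled kernel):

.. code-block:: Python


    # Assumed device properties and kernel statistics -- placeholders, not queried values.
    NUM_SM, NUM_REGS, SIZE_SMEM, WARP_SIZE = 108, 65536, 101376, 32
    n_regs, size_smem, num_warps, n_rows = 32, 16384, 8, 1823

    occupancy = NUM_REGS // (n_regs * WARP_SIZE * num_warps)  # 8 programs per SM, limited by registers
    occupancy = min(occupancy, SIZE_SMEM // size_smem)        # 6 programs per SM, limited by shared memory
    num_programs = min(NUM_SM * occupancy, n_rows)            # 648 persistent programs, capped at n_rows
    print(num_programs)  # 648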
.. GENERATED FROM PYTHON SOURCE LINES 179-181

Unit Test
---------

.. GENERATED FROM PYTHON SOURCE LINES 183-185

We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.

.. GENERATED FROM PYTHON SOURCE LINES 185-192

.. code-block:: Python


    torch.manual_seed(0)
    x = torch.randn(1823, 781, device=DEVICE)
    y_triton = softmax(x)
    y_torch = torch.softmax(x, axis=1)
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)


.. GENERATED FROM PYTHON SOURCE LINES 193-194

As expected, the results are identical.

.. GENERATED FROM PYTHON SOURCE LINES 196-201

Benchmark
---------

Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.

.. GENERATED FROM PYTHON SOURCE LINES 201-231

.. code-block:: Python


    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],  # argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
            line_arg='provider',  # argument name whose value corresponds to a different line in the plot
            line_vals=['triton', 'torch', 'naive_softmax'],  # possible values for `line_arg`
            line_names=["Triton", "Torch", "Naive Softmax"],  # label name for the lines
            styles=[('blue', '-'), ('green', '-'), ('red', '-')],  # line styles
            ylabel="GB/s",  # label name for the y-axis
            plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
            args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
        ))
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
        stream = getattr(torch, DEVICE.type).Stream()
        getattr(torch, DEVICE.type).set_stream(stream)
        if provider == 'torch':
            ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
        if provider == 'triton':
            ms = triton.testing.do_bench(lambda: softmax(x))
        if provider == 'naive_softmax':
            ms = triton.testing.do_bench(lambda: naive_softmax(x))
        gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms)


    benchmark.run(show_plots=True, print_data=True)
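The reported GB/s figure is the effective DRAM bandwidth of a perfectly fused softmax: each
element is read once and written once, i.e. :code:`2 * x.numel() * x.element_size()` bytes
in total. The following snippet (an editorial illustration with an assumed timing, not part
of the generated tutorial code) shows how a measured time in milliseconds is converted into
that figure:

.. code-block:: Python


    # Assumed values: a 4096 x 1024 float32 matrix and a made-up 0.05 ms timing from do_bench.
    M, N, elapsed_ms = 4096, 1024, 0.05
    bytes_moved = 2 * M * N * 4               # read MN + write MN elements, 4 bytes each
    gbps = bytes_moved * 1e-9 / (elapsed_ms * 1e-3)
    print(round(gbps, 1))                     # ~671.1 GB/s for this made-up timing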
.. image-sg:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :alt: 02 fused softmax
   :srcset: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    softmax-performance:
              N       Triton        Torch  Naive Softmax
    0     256.0   464.494833   689.887701     204.117691
    1     384.0   656.365253   817.772358     261.865478
    2     512.0   809.777647   917.699776     301.214046
    3     640.0   917.229944   906.743348     329.711966
    4     768.0   985.714290   986.275948     348.097695
    5     896.0  1054.692093  1043.861067     356.143757
    6    1024.0  1089.366968  1073.377807     354.390369
    7    1152.0  1103.254821  1074.422131     349.208118
    8    1280.0  1129.322993  1102.231277     349.214586
    9    1408.0  1163.484730  1141.097625     341.398909
    10   1536.0  1190.414296  1166.203955     332.447303
    11   1664.0  1207.133491  1191.474794     329.028693
    12   1792.0  1227.466998  1199.756340     324.882040
    13   1920.0  1261.978259  1223.954166     324.547100
    14   2048.0  1274.462283  1250.811409     324.864095
    15   2176.0  1232.505554   958.403117     325.739931
    16   2304.0  1261.144185  1000.426686     326.300233
    17   2432.0  1271.108788  1034.408114     327.151966
    18   2560.0  1283.158464  1072.505303     328.559872
    19   2688.0  1298.229090  1097.078556     329.438697
    20   2816.0  1314.152587  1127.056822     329.654716
    21   2944.0  1319.041275  1139.299449     331.140618
    22   3072.0  1321.341286  1167.074433     332.832861
    23   3200.0  1337.628426  1172.370275     334.344933
    24   3328.0  1340.777014  1201.293522     336.659289
    25   3456.0  1347.137693  1217.946985     336.809528
    26   3584.0  1358.604680  1247.717974     338.523726
    27   3712.0  1366.010794  1266.415280     340.740559
    28   3840.0  1369.867995  1280.033655     340.729201
    29   3968.0  1374.929376  1301.219897     341.353283
    30   4096.0  1391.012730  1320.556837     338.883019
    31   4224.0  1333.229085  1274.952259     343.155065
    32   4352.0  1337.294903  1293.840698     345.122916
    33   4480.0  1350.463267  1314.006861     345.776189
    34   4608.0  1364.220882  1333.172589     346.803478
    35   4736.0  1357.999729  1341.524925     347.783735
    36   4864.0  1367.648508  1357.512641     349.067675
    37   4992.0  1366.739122  1373.290949     350.258541
    38   5120.0  1375.252302  1388.179636     351.564108
    39   5248.0  1376.723732  1348.162625     351.972486
    40   5376.0  1380.962714  1372.791127     351.912573
    41   5504.0  1379.313181  1382.563595     353.637979
    42   5632.0  1393.867444  1395.198487     353.329247
    43   5760.0  1393.123687  1406.574209     355.037478
    44   5888.0  1390.238996  1418.640385     354.989148
    45   6016.0  1404.114897  1415.033965     356.643239
    46   6144.0  1409.423850  1423.209358     357.039429
    47   6272.0  1409.376338  1399.528622     357.709945
    48   6400.0  1415.114493  1403.859122     358.678241
    49   6528.0  1413.071010  1420.788514     359.010696
    50   6656.0  1410.027636  1426.964017     359.705531
    51   6784.0  1418.185885  1442.042203     360.206258
    52   6912.0  1426.206600  1444.386403     360.918781
    53   7040.0  1422.870169  1453.720945     360.855824
    54   7168.0  1419.942081  1456.791865     361.773392
    55   7296.0  1422.876969  1085.674948     362.319183
    56   7424.0  1426.284958  1094.820523     362.964861
    57   7552.0  1423.301006  1111.019671     363.336214
    58   7680.0  1433.699555  1125.547802     364.079824
    59   7808.0  1427.350714  1131.753417     364.766938
    60   7936.0  1437.568487  1138.856246     364.618458
    61   8064.0  1436.850026  1145.052631     364.890497
    62   8192.0  1432.452514  1150.195552     364.424825
    63   8320.0  1384.319221  1117.010960     361.605138
    64   8448.0  1381.862798  1126.905099     362.509984
    65   8576.0  1390.546122  1126.074320     363.321233
    66   8704.0  1384.876305  1129.847527     364.106681
    67   8832.0  1397.379708  1134.563482     364.658374
    68   8960.0  1383.546595  1138.316851     365.621015
    69   9088.0  1396.658666  1135.711925     366.448948
    70   9216.0  1407.358622  1143.013462     367.569410
    71   9344.0  1390.482255  1421.328601     367.841751
    72   9472.0  1402.240933  1431.741578     368.694807
    73   9600.0  1400.405216  1430.578657     369.323986
    74   9728.0  1397.607405  1440.956045     369.566893
    75   9856.0  1404.321327  1441.340452     370.088841
    76   9984.0  1395.292711  1445.367142     370.716371
    77  10112.0  1403.625582  1455.459976     371.255529
    78  10240.0  1412.063130  1467.971402     371.375935
    79  10368.0  1411.251288  1460.491691     370.179094
    80  10496.0  1406.684157  1466.162689     370.737439
    81  10624.0  1408.797884  1468.619786     371.170821
    82  10752.0  1393.502079  1473.106921     371.872724
    83  10880.0  1399.571212  1479.129153     372.169036
    84  11008.0  1421.185023  1478.115338     372.983841
    85  11136.0  1416.411054  1484.572782     372.894978
    86  11264.0  1413.616742  1487.464483     373.507895
    87  11392.0  1423.588975  1489.828751     373.499399
    88  11520.0  1418.735579  1497.413974     374.240550
    89  11648.0  1417.417087  1497.081407     374.994791
    90  11776.0  1435.745824  1501.248252     374.979428
    91  11904.0  1432.396716  1508.562711     375.708809
    92  12032.0  1413.758016  1509.847645     375.845225
    93  12160.0  1411.196955  1516.607497     376.134201
    94  12288.0  1425.814558  1419.030985     375.821491
    95  12416.0  1437.263881  1395.864226     375.154437
    96  12544.0  1443.486896  1394.456954     375.757835
    97  12672.0  1440.237882  1392.959954     375.237975


.. GENERATED FROM PYTHON SOURCE LINES 232-236

In the above plot, we can see that:

 - Triton is about 4x faster than :code:`naive_softmax`. This confirms our earlier analysis: the naive implementation is not fused and moves roughly four times as much data through DRAM.
 - Triton is noticeably faster than :code:`torch.softmax` for most of the shapes tested -- in addition to being **easier to read, understand and maintain**.
   Note however that the PyTorch :code:`softmax` operation is more general and will work on tensors of any shape.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 36.023 seconds)


.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 02-fused-softmax.zip <02-fused-softmax.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_