.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/02-fused-softmax.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_getting-started_tutorials_02-fused-softmax.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_02-fused-softmax.py:


Fused Softmax
=============

In this tutorial, you will write a fused softmax operation that is significantly faster
than PyTorch's native op for a particular class of matrices: those whose rows can fit in
the GPU's SRAM.

In doing so, you will learn about:

* The benefits of kernel fusion for bandwidth-bound operations.

* Reduction operators in Triton.

.. GENERATED FROM PYTHON SOURCE LINES 18-23

Motivations
-----------

Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice.
Let us consider instead the case of a simple (numerically stabilized) softmax operation:

.. GENERATED FROM PYTHON SOURCE LINES 23-51

.. code-block:: Python


    import torch

    import triton
    import triton.language as tl


    @torch.jit.script
    def naive_softmax(x):
        """Compute row-wise softmax of X using native pytorch

        We subtract the maximum element in order to avoid overflows. Softmax is invariant to
        this shift.
        """
        # read  MN elements ; write M  elements
        x_max = x.max(dim=1)[0]
        # read MN + M elements ; write MN elements
        z = x - x_max[:, None]
        # read  MN elements ; write MN elements
        numerator = torch.exp(z)
        # read  MN elements ; write M  elements
        denominator = numerator.sum(dim=1)
        # read MN + M elements ; write MN elements
        ret = numerator / denominator[:, None]
        # in total: read 5MN + 2M elements ; wrote 3MN + 2M elements
        return ret

..
.. GENERATED FROM PYTHON SOURCE LINES 52-60

When implemented naively in PyTorch, computing :code:`y = naive_softmax(x)` for :math:`x \in R^{M \times N}`
requires reading :math:`5MN + 2M` elements from DRAM and writing back :math:`3MN + 2M` elements.
This is obviously wasteful; we'd prefer to have a custom "fused" kernel that only reads
X once and does all the necessary computations on-chip.
Doing so would require reading only :math:`MN` elements and writing back only :math:`MN` elements,
so we could expect a theoretical speed-up of ~4x (i.e., :math:`(8MN + 4M) / 2MN`).
The :code:`torch.jit.script` decorator aims to perform this kind of "kernel fusion" automatically
but, as we will see later, it is still far from ideal.

.. GENERATED FROM PYTHON SOURCE LINES 62-71

Compute Kernel
--------------

Our softmax kernel works as follows: each program loads a row of the input matrix X,
normalizes it and writes back the result to the output Y.

Note that one important limitation of Triton is that each block must have a
power-of-two number of elements, so we need to internally "pad" each row and guard the
memory operations properly if we want to handle any possible input shape:

.. GENERATED FROM PYTHON SOURCE LINES 71-97

..
.. code-block:: Python


    @triton.jit
    def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):
        # The rows of the softmax are independent, so we parallelize across those
        row_idx = tl.program_id(0)
        # The stride represents how much we need to increase the pointer to advance 1 row
        row_start_ptr = input_ptr + row_idx * input_row_stride
        # The block size is the smallest power of two greater than or equal to n_cols,
        # so we can fit each row in a single block
        col_offsets = tl.arange(0, BLOCK_SIZE)
        input_ptrs = row_start_ptr + col_offsets
        # Load the row into SRAM, using a mask since BLOCK_SIZE may be greater than n_cols
        row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))
        # Subtract maximum for numerical stability
        row_minus_max = row - tl.max(row, axis=0)
        # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)
        softmax_output = numerator / denominator
        # Write back output to DRAM
        output_row_start_ptr = output_ptr + row_idx * output_row_stride
        output_ptrs = output_row_start_ptr + col_offsets
        tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)

.. GENERATED FROM PYTHON SOURCE LINES 98-99

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

.. GENERATED FROM PYTHON SOURCE LINES 99-130

.. code-block:: Python


    def softmax(x):
        n_rows, n_cols = x.shape
        # The block size is the smallest power of two greater than or equal to the number of columns in `x`
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        # Another trick we can use is to ask the compiler to use more threads per row by
        # increasing the number of warps (`num_warps`) over which each row is distributed.
        # You will see in the next tutorial how to auto-tune this value in a more natural
        # way so you don't have to come up with manual heuristics yourself.
        num_warps = 4
        if BLOCK_SIZE >= 2048:
            num_warps = 8
        if BLOCK_SIZE >= 4096:
            num_warps = 16
        # Allocate output
        y = torch.empty_like(x)
        # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of
        # the input matrix
        softmax_kernel[(n_rows, )](
            y,
            x,
            x.stride(0),
            y.stride(0),
            n_cols,
            num_warps=num_warps,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return y

.. GENERATED FROM PYTHON SOURCE LINES 131-133

Unit Test
---------

.. GENERATED FROM PYTHON SOURCE LINES 135-137

We make sure that we test our kernel on a matrix with an irregular number of rows and columns.
This will allow us to verify that our padding mechanism works.

.. GENERATED FROM PYTHON SOURCE LINES 137-144

.. code-block:: Python


    torch.manual_seed(0)
    x = torch.randn(1823, 781, device='cuda')
    y_triton = softmax(x)
    y_torch = torch.softmax(x, axis=1)
    assert torch.allclose(y_triton, y_torch), (y_triton, y_torch)

.. GENERATED FROM PYTHON SOURCE LINES 145-146

As expected, the results match to within the tolerance of :code:`torch.allclose`.

.. GENERATED FROM PYTHON SOURCE LINES 148-153

Benchmark
---------

Here we will benchmark our operation as a function of the number of columns in the input matrix -- assuming 4096 rows.
We will then compare its performance against (1) :code:`torch.softmax` and (2) the :code:`naive_softmax` defined above.

.. GENERATED FROM PYTHON SOURCE LINES 153-190

.. code-block:: Python


    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],  # argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
            line_arg='provider',  # argument name whose value corresponds to a different line in the plot
            line_vals=[
                'triton',
                'torch-native',
                'torch-jit',
            ],  # possible values for `line_arg`
            line_names=[
                "Triton",
                "Torch (native)",
                "Torch (jit)",
            ],  # label name for the lines
            styles=[('blue', '-'), ('green', '-'), ('green', '--')],  # line styles
            ylabel="GB/s",  # label name for the y-axis
            plot_name="softmax-performance",  # name for the plot.
            # Used also as a file name for saving the plot.
            args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
        ))
    def benchmark(M, N, provider):
        x = torch.randn(M, N, device='cuda', dtype=torch.float32)
        # Median, 20th and 80th percentile of the measured run times
        quantiles = [0.5, 0.2, 0.8]
        if provider == 'torch-native':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1), quantiles=quantiles)
        if provider == 'triton':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x), quantiles=quantiles)
        if provider == 'torch-jit':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x), quantiles=quantiles)
        # Effective bandwidth: each element is read once and written once
        gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
        return gbps(ms), gbps(max_ms), gbps(min_ms)


    benchmark.run(show_plots=True, print_data=True)




.. image-sg:: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :alt: 02 fused softmax
   :srcset: /getting-started/tutorials/images/sphx_glr_02-fused-softmax_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    softmax-performance:
              N       Triton  Torch (native)  Torch (jit)
    0     256.0   682.666643      744.727267   256.000001
    1     384.0   877.714274      819.200021   315.076934
    2     512.0   910.222190      910.222190   348.595735
    3     640.0   975.238103      975.238103   372.363633
    4     768.0  1068.521715     1023.999964   390.095241
    ..      ...          ...             ...          ...
    93  12160.0  1601.316858     1514.085562   464.898431
    94  12288.0  1598.438956     1389.455824   465.344393
    95  12416.0  1602.064538     1389.202822   463.067592
    96  12544.0  1599.235121     1398.634134   464.055486
    97  12672.0  1602.782573     1398.289673   463.433129

    [98 rows x 4 columns]




.. GENERATED FROM PYTHON SOURCE LINES 191-195

In the above plot, we can see that:

- Triton is roughly 3.5x faster than the Torch JIT for large N (about 1603 GB/s versus 463 GB/s at N = 12672), close to the theoretical ~4x estimate. This supports our suspicion that the Torch JIT does not do any fusion here.
- Triton is noticeably faster than :code:`torch.softmax` for large N -- in addition to being **easier to read, understand and maintain**.

Note however that the PyTorch `softmax` operation is more general and will work on tensors of any shape.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 42.347 seconds)


.. _sphx_glr_download_getting-started_tutorials_02-fused-softmax.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 02-fused-softmax.ipynb <02-fused-softmax.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 02-fused-softmax.py <02-fused-softmax.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
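
The padding and masking that the compute kernel relies on can be emulated on the CPU, which may help when reasoning about irregular row lengths such as the 781 columns used in the unit test. The following is a minimal pure-Python sketch, not the Triton implementation: ``next_power_of_2`` is a stand-in written here for :code:`triton.next_power_of_2`, and ``padded_row_softmax`` is a hypothetical helper name introduced only for illustration. It assumes a non-empty row of finite values.

.. code-block:: Python

    import math
    from typing import List


    def next_power_of_2(n: int) -> int:
        """Smallest power of two >= n (stand-in for triton.next_power_of_2)."""
        return 1 if n <= 1 else 1 << (n - 1).bit_length()


    def padded_row_softmax(row: List[float]) -> List[float]:
        """Emulate one kernel instance: pad, reduce, normalize, masked store."""
        n_cols = len(row)
        block_size = next_power_of_2(n_cols)
        # Masked load: lanes at or beyond n_cols read -inf instead of memory
        block = [row[i] if i < n_cols else -math.inf for i in range(block_size)]
        # Subtract the maximum for numerical stability; padded lanes stay -inf
        row_max = max(block)
        # exp(-inf) == 0.0, so padded lanes contribute nothing to the sum
        numerator = [math.exp(v - row_max) for v in block]
        denominator = sum(numerator)
        out = [v / denominator for v in numerator]
        # Masked store: only the first n_cols lanes are written back
        return out[:n_cols]

Padding with ``-inf`` rather than zero is the design choice that matters here: a zero would contribute ``exp(0 - row_max)`` to the denominator and could even become the row maximum when all real values are negative, whereas ``-inf`` is neutral for both the max and the sum.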