.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "getting-started/tutorials/09-persistent-matmul.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_getting-started_tutorials_09-persistent-matmul.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_getting-started_tutorials_09-persistent-matmul.py:


Persistent Matmul
=====================
This script demonstrates persistent kernel implementations of matrix multiplication using Triton.
It includes several matmul variants, such as naive, persistent, and TMA (Tensor Memory Accelerator)
based approaches. The kernels support both FP16 and FP8 data types, but the FP8 implementation is
only available on CUDA devices with compute capability >= 9.0.

The Triton and cuBLAS implementations are benchmarked under different configurations and evaluated
using the proton profiler. Users can pass command-line arguments to specify matrix dimensions and
iteration steps flexibly.

.. code-block:: bash

    # FP8
    python 09-persistent-matmul.py --prec fp8 --K_range 128 1024 --K_step 128

    # FP16
    python 09-persistent-matmul.py --prec fp16 --K_range 128 1024 --K_step 128

Note that currently this tutorial will fail on devices with a small shared memory size, such as the RTX-4090.

.. GENERATED FROM PYTHON SOURCE LINES 21-719

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    M=32, N=32, K=32, verification naive vs:
     Torch: ... Torch: ✅
     cuBLAS: ... cuBLAS: ✅
     Persistent: ... Persistent: ✅
     TMA (warp_specialize=False): ... TMA (warp_specialize=False): ⭕
     TMA Persistent (warp_specialize=False): ... TMA Persistent (warp_specialize=False): ⭕
     Tensor Descriptor Persistent (warp_specialize=False): ... Tensor Descriptor Persistent (warp_specialize=False): ⭕

    M=8192, N=8192, K=512, verification naive vs:
     Torch: ... Torch: ✅
     cuBLAS: ... cuBLAS: ✅
     Persistent: ... Persistent: ✅
     TMA (warp_specialize=False): ... TMA (warp_specialize=False): ⭕
     TMA Persistent (warp_specialize=False): ... TMA Persistent (warp_specialize=False): ⭕
     Tensor Descriptor Persistent (warp_specialize=False): ... Tensor Descriptor Persistent (warp_specialize=False): ⭕

    Benchmarking cublas: ... Benchmarking cublas: done
    Benchmarking torch: ... Benchmarking torch: done
    Benchmarking naive: ... Benchmarking naive: done
    Benchmarking persistent: ... Benchmarking persistent: done
    169.000 16264.975 ROOT
    ├─ 176.035 3903.739 cublas [M=8192, N=8192, K=512]
    │  └─ nan 3903.739 ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_32x5_tn
    ├─ 166.203 4134.670 matmul_kernel [M=8192, N=8192, K=512]
    ├─ 159.226 4315.852 matmul_kernel_persistent [M=8192, N=8192, K=512]
    └─ 175.721 3910.714 torch [M=8192, N=8192, K=512]
       └─ nan 3910.714 ampere_fp16_s16816gemm_fp16_128x128_ldg8_f2f_stages_32x5_tn

|

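The full source of the tutorial follows. If you want to experiment with the individual
implementations interactively, the host-side wrappers defined below (``matmul``,
``matmul_persistent``, ``matmul_tma_persistent``, ``matmul_descriptor_persistent``) can also be
called directly. A minimal sketch, assuming the script has been saved as an importable module
named ``persistent_matmul`` (a hypothetical name) and a CUDA device is available:

.. code-block:: Python

    import torch

    # Hypothetical module name; adjust it to wherever the tutorial file lives on your path.
    from persistent_matmul import matmul, matmul_persistent

    a = torch.randn((512, 256), device="cuda", dtype=torch.float16)
    b = torch.randn((256, 512), device="cuda", dtype=torch.float16)

    c_naive = matmul(a, b)                   # one program per output tile
    c_persistent = matmul_persistent(a, b)   # at most NUM_SMS persistent programs
    # Both variants compute the same product, up to floating-point accumulation differences;
    # the tutorial's own verification uses the same absolute tolerance of 1.0.
    torch.testing.assert_close(c_naive, c_persistent, atol=1.0, rtol=0.0)
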
.. code-block:: Python

    import argparse
    import itertools

    import torch
    import triton
    import triton.language as tl
    import triton.profiler as proton
    from triton.tools.tensor_descriptor import TensorDescriptor
    from contextlib import contextmanager
    from typing import Optional

    if torch.cuda.is_available():
        from triton._C.libtriton import nvidia

        cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
        cublas = nvidia.cublas.CublasLt(cublas_workspace)
    else:
        cublas = None


    def is_cuda():
        return triton.runtime.driver.active.get_current_target().backend == "cuda"


    def supports_tma():
        return is_cuda() and torch.cuda.get_device_capability()[0] >= 9


    def supports_ws():
        return is_cuda() and torch.cuda.get_device_capability()[0] >= 10


    def _matmul_launch_metadata(grid, kernel, args):
        ret = {}
        M, N, K, WS = args["M"], args["N"], args["K"], args.get("WARP_SPECIALIZE", False)
        ws_str = "_ws" if WS else ""
        ret["name"] = f"{kernel.name}{ws_str} [M={M}, N={N}, K={K}]"
        if "c_ptr" in args:
            bytes_per_elem = args["c_ptr"].element_size()
        else:
            bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
        ret[f"flops{bytes_per_elem * 8}"] = 2. * M * N * K
        ret["bytes"] = bytes_per_elem * (M * K + N * K + M * N)
        return ret


    HAS_TENSOR_DESC = supports_tma() and hasattr(tl, "make_tensor_descriptor")
    HAS_HOST_TENSOR_DESC = supports_tma() and hasattr(triton.tools.tensor_descriptor, "TensorDescriptor")
    HAS_WARP_SPECIALIZE = supports_ws() and HAS_TENSOR_DESC


    def matmul_get_configs(pre_hook=None):
        return [
            triton.Config({'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K" : BK, "GROUP_SIZE_M" : 8}, num_stages=s, num_warps=w, pre_hook=pre_hook) \
            for BM in [128] \
            for BN in [128, 256] \
            for BK in [64,128] \
            for s in ([3,4]) \
            for w in [4,8] \
        ]


    @triton.autotune(
        configs=matmul_get_configs(),
        key=["M", "N", "K"],
    )
    @triton.jit(launch_metadata=_matmul_launch_metadata)
    def matmul_kernel(a_ptr, b_ptr, c_ptr,  #
                      M, N, K,  #
                      stride_am, stride_ak,  #
                      stride_bk, stride_bn,  #
                      stride_cm, stride_cn,  #
                      BLOCK_SIZE_M: tl.constexpr,  #
                      BLOCK_SIZE_N: tl.constexpr,  #
                      BLOCK_SIZE_K: tl.constexpr,  #
                      GROUP_SIZE_M: tl.constexpr,  #
                      ):
        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N

        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)

        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
            accumulator = tl.dot(a, b, accumulator)
            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

        if (c_ptr.dtype.element_ty == tl.float8e4nv):
            c = accumulator.to(tl.float8e4nv)
        else:
            c = accumulator.to(tl.float16)
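        # Epilogue: compute the row/column indices of this program's output tile and store the
        # converted result, masking off elements that fall outside the M x N bounds.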
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        tl.store(c_ptrs, c, mask=c_mask)


    def matmul(a, b):
        # Check constraints.
        assert a.shape[1] == b.shape[0], "Incompatible dimensions"
        assert a.dtype == b.dtype, "Incompatible dtypes"
        M, K = a.shape
        K, N = b.shape
        dtype = a.dtype

        c = torch.empty((M, N), device=a.device, dtype=dtype)
        # 1D launch kernel where each block gets its own program.
        grid = lambda META: (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), )
        matmul_kernel[grid](
            a, b, c,  #
            M, N, K,  #
            a.stride(0), a.stride(1),  #
            b.stride(0), b.stride(1),  #
            c.stride(0), c.stride(1),  #
        )
        return c


    def matmul_tma_set_block_size_hook(nargs):
        EPILOGUE_SUBTILE = nargs.get("EPILOGUE_SUBTILE", False)
        BLOCK_M = nargs["BLOCK_SIZE_M"]
        BLOCK_N = nargs["BLOCK_SIZE_N"]
        BLOCK_K = nargs["BLOCK_SIZE_K"]
        nargs["a_desc"].block_shape = [BLOCK_M, BLOCK_K]
        nargs["b_desc"].block_shape = [BLOCK_N, BLOCK_K]
        if EPILOGUE_SUBTILE:
            nargs["c_desc"].block_shape = [BLOCK_M, BLOCK_N // 2]
        else:
            nargs["c_desc"].block_shape = [BLOCK_M, BLOCK_N]


    @triton.autotune(
        configs=matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook),
        key=["M", "N", "K", "WARP_SPECIALIZE"],
    )
    @triton.jit(launch_metadata=_matmul_launch_metadata)
    def matmul_kernel_tma(a_desc, b_desc, c_desc,  #
                          M, N, K,  #
                          BLOCK_SIZE_M: tl.constexpr,  #
                          BLOCK_SIZE_N: tl.constexpr,  #
                          BLOCK_SIZE_K: tl.constexpr,  #
                          GROUP_SIZE_M: tl.constexpr,  #
                          FP8_OUTPUT: tl.constexpr,  #
                          WARP_SPECIALIZE: tl.constexpr,  #
                          ):
        dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16

        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        k_tiles = tl.cdiv(K, BLOCK_SIZE_K)

        offs_am = pid_m * BLOCK_SIZE_M
        offs_bn = pid_n * BLOCK_SIZE_N

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in tl.range(k_tiles, warp_specialize=WARP_SPECIALIZE):
            offs_k = k * BLOCK_SIZE_K
            a = a_desc.load([offs_am, offs_k])
            b = b_desc.load([offs_bn, offs_k])
            accumulator = tl.dot(a, b.T, accumulator)

        c = accumulator.to(dtype)

        offs_cm = pid_m * BLOCK_SIZE_M
        offs_cn = pid_n * BLOCK_SIZE_N
        c_desc.store([offs_cm, offs_cn], c)

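    # Host-side wrapper for matmul_kernel_tma: TensorDescriptor objects are created for A, B
    # (stored transposed as (N, K)) and C with placeholder block shapes; the autotuner pre-hook
    # matmul_tma_set_block_size_hook fills in the selected BLOCK_SIZE_{M,N,K} before each launch.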
    def matmul_tma(a, b, warp_specialize: bool):
        # Check constraints.
        assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
        assert a.dtype == b.dtype, "Incompatible dtypes"
        M, K = a.shape
        N, K = b.shape
        dtype = a.dtype

        c = torch.empty((M, N), device=a.device, dtype=dtype)

        # A dummy block value that will be overwritten when we have the real block size
        dummy_block = [1, 1]
        a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        def grid(META):
            BLOCK_M = META["BLOCK_SIZE_M"]
            BLOCK_N = META["BLOCK_SIZE_N"]
            return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), )

        matmul_kernel_tma[grid](
            a_desc, b_desc, c_desc,  #
            M, N, K,  #
            FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
            WARP_SPECIALIZE=warp_specialize,  #
        )
        return c


    @triton.jit
    def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
        group_id = tile_id // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (tile_id % group_size_m)
        pid_n = (tile_id % num_pid_in_group) // group_size_m
        return pid_m, pid_n


    @triton.autotune(
        configs=matmul_get_configs(),
        key=["M", "N", "K"],
    )
    @triton.jit(launch_metadata=_matmul_launch_metadata)
    def matmul_kernel_persistent(a_ptr, b_ptr, c_ptr,  #
                                 M, N, K,  #
                                 stride_am, stride_ak,  #
                                 stride_bk, stride_bn,  #
                                 stride_cm, stride_cn,  #
                                 BLOCK_SIZE_M: tl.constexpr,  #
                                 BLOCK_SIZE_N: tl.constexpr,  #
                                 BLOCK_SIZE_K: tl.constexpr,  #
                                 GROUP_SIZE_M: tl.constexpr,  #
                                 NUM_SMS: tl.constexpr,  #
                                 ):
        start_pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
        num_tiles = num_pid_m * num_pid_n

        # NOTE: There is currently a bug in blackwell pipelining that means it can't handle a value being
        # used in both the prologue and epilogue, so we duplicate the counters as a work-around.
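        # Persistent scheduling: the launch grid is capped at NUM_SMS programs (see
        # matmul_persistent below), so each program stays resident on its SM and processes the
        # output tiles start_pid, start_pid + NUM_SMS, start_pid + 2 * NUM_SMS, ...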
        tile_id_c = start_pid - NUM_SMS

        offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
        num_pid_in_group = GROUP_SIZE_M * num_pid_n

        for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
            pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
            start_m = pid_m * BLOCK_SIZE_M
            start_n = pid_n * BLOCK_SIZE_N
            offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
            offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
            offs_am = tl.where(offs_am < M, offs_am, 0)
            offs_bn = tl.where(offs_bn < N, offs_bn, 0)
            offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
            offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for ki in range(k_tiles):
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
                a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
                b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

                a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0)
                b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0)
                accumulator = tl.dot(a, b, accumulator)

            tile_id_c += NUM_SMS
            pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
            offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
            offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
            c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
            c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
            if (c_ptr.dtype.element_ty == tl.float8e4nv):
                c = accumulator.to(tl.float8e4nv)
            else:
                c = accumulator.to(tl.float16)
            tl.store(c_ptrs, c, mask=c_mask)


    def matmul_persistent(a, b):
        # Check constraints.
        assert a.shape[1] == b.shape[0], "Incompatible dimensions"
        assert a.dtype == b.dtype, "Incompatible dtypes"
        NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
        M, K = a.shape
        K, N = b.shape
        dtype = a.dtype
        # Allocates output.
        c = torch.empty((M, N), device=a.device, dtype=dtype)
        # 1D launch kernel where each block gets its own program.
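        # The grid size is capped at NUM_SMS, so a single persistent program may compute several
        # output tiles rather than exactly one.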
        grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
        matmul_kernel_persistent[grid](
            a, b, c,  #
            M, N, K,  #
            a.stride(0), a.stride(1),  #
            b.stride(0), b.stride(1),  #
            c.stride(0), c.stride(1),  #
            NUM_SMS=NUM_SMS,  #
        )
        return c


    def matmul_tma_persistent_get_configs(pre_hook=None):
        return [
            triton.Config(
                {
                    'BLOCK_SIZE_M': BM, 'BLOCK_SIZE_N': BN, "BLOCK_SIZE_K": BK, "GROUP_SIZE_M": 8,
                    "EPILOGUE_SUBTILE": SUBTILE
                }, num_stages=s, num_warps=w, pre_hook=pre_hook)  #
            for BM in [128]  #
            for BN in [128, 256]  #
            for BK in [64, 128]  #
            for s in ([2, 3, 4])  #
            for w in [4, 8]  #
            for SUBTILE in [True, False]  #
        ]


    @triton.autotune(
        configs=matmul_tma_persistent_get_configs(pre_hook=matmul_tma_set_block_size_hook),
        key=["M", "N", "K", "WARP_SPECIALIZE"],
    )
    @triton.jit(launch_metadata=_matmul_launch_metadata)
    def matmul_kernel_tma_persistent(a_desc, b_desc, c_desc,  #
                                     M, N, K,  #
                                     BLOCK_SIZE_M: tl.constexpr,  #
                                     BLOCK_SIZE_N: tl.constexpr,  #
                                     BLOCK_SIZE_K: tl.constexpr,  #
                                     GROUP_SIZE_M: tl.constexpr,  #
                                     FP8_OUTPUT: tl.constexpr,  #
                                     EPILOGUE_SUBTILE: tl.constexpr,  #
                                     NUM_SMS: tl.constexpr,  #
                                     WARP_SPECIALIZE: tl.constexpr,  #
                                     ):
        dtype = tl.float8e4nv if FP8_OUTPUT else tl.float16
        start_pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
        num_tiles = num_pid_m * num_pid_n

        tile_id_c = start_pid - NUM_SMS
        num_pid_in_group = GROUP_SIZE_M * num_pid_n

        # Enable warp specialization to leverage async warp scheduling in the GPU.
        # FIXME: This only works on Blackwell right now. On older GPUs, this will
        # use software pipelining.
        for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE):
            pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
            offs_am = pid_m * BLOCK_SIZE_M
            offs_bn = pid_n * BLOCK_SIZE_N

            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for ki in range(k_tiles):
                offs_k = ki * BLOCK_SIZE_K
                a = a_desc.load([offs_am, offs_k])
                b = b_desc.load([offs_bn, offs_k])
                accumulator = tl.dot(a, b.T, accumulator)

            tile_id_c += NUM_SMS
            pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
            offs_am_c = pid_m * BLOCK_SIZE_M
            offs_bn_c = pid_n * BLOCK_SIZE_N

            # Epilogue subtiling is a technique to break our computation and stores into multiple pieces
            # By subtiling we can reduce shared memory consumption by the epilogue and instead use that
            # memory to increase our stage count.
            # In this case we partition the accumulator into 2 BLOCK_SIZE_M x BLOCK_SIZE_N // 2 tensors
            if EPILOGUE_SUBTILE:
                acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                acc = tl.permute(acc, (0, 2, 1))
                acc0, acc1 = tl.split(acc)
                c0 = acc0.to(dtype)
                c_desc.store([offs_am_c, offs_bn_c], c0)
                c1 = acc1.to(dtype)
                c_desc.store([offs_am_c, offs_bn_c + BLOCK_SIZE_N // 2], c1)
            else:
                accumulator = accumulator.to(dtype)
                c_desc.store([offs_am_c, offs_bn_c], accumulator)

    def matmul_tma_persistent(a, b, warp_specialize: bool):
        # Check constraints.
        assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
        assert a.dtype == b.dtype, "Incompatible dtypes"
        M, K = a.shape
        N, K = b.shape
        dtype = a.dtype

        c = torch.empty((M, N), device=a.device, dtype=dtype)
        NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

        # A dummy block value that will be overwritten when we have the real block size
        dummy_block = [1, 1]
        a_desc = TensorDescriptor(a, a.shape, a.stride(), dummy_block)
        b_desc = TensorDescriptor(b, b.shape, b.stride(), dummy_block)
        c_desc = TensorDescriptor(c, c.shape, c.stride(), dummy_block)

        def grid(META):
            nonlocal a_desc, b_desc, c_desc
            BLOCK_M = META["BLOCK_SIZE_M"]
            BLOCK_N = META["BLOCK_SIZE_N"]
            return (min(
                NUM_SMS,
                triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
            ), )

        matmul_kernel_tma_persistent[grid](
            a_desc, b_desc, c_desc,  #
            M, N, K,  #
            FP8_OUTPUT=dtype == torch.float8_e4m3fn,  #
            NUM_SMS=NUM_SMS,  #
            WARP_SPECIALIZE=warp_specialize,  #
        )
        return c


    @triton.autotune(
        configs=matmul_tma_persistent_get_configs(),
        key=["M", "N", "K", "WARP_SPECIALIZE"],
    )
    @triton.jit(launch_metadata=_matmul_launch_metadata)
    def matmul_kernel_descriptor_persistent(a_ptr, b_ptr, c_ptr,  #
                                            M, N, K,  #
                                            BLOCK_SIZE_M: tl.constexpr,  #
                                            BLOCK_SIZE_N: tl.constexpr,  #
                                            BLOCK_SIZE_K: tl.constexpr,  #
                                            GROUP_SIZE_M: tl.constexpr,  #
                                            EPILOGUE_SUBTILE: tl.constexpr,  #
                                            NUM_SMS: tl.constexpr,  #
                                            WARP_SPECIALIZE: tl.constexpr,  #
                                            ):
        # Matmul using TMA and device-side descriptor creation
        dtype = c_ptr.dtype.element_ty
        start_pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
        num_tiles = num_pid_m * num_pid_n

        a_desc = tl.make_tensor_descriptor(
            a_ptr,
            shape=[M, K],
            strides=[K, 1],
            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
        )
        b_desc = tl.make_tensor_descriptor(
            b_ptr,
            shape=[N, K],
            strides=[K, 1],
            block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K],
        )
        c_desc = tl.make_tensor_descriptor(
            c_ptr,
            shape=[M, N],
            strides=[N, 1],
            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N if not EPILOGUE_SUBTILE else BLOCK_SIZE_N // 2],
        )

        # tile_id_c is used in the epilogue to break the dependency between
        # the prologue and the epilogue
        tile_id_c = start_pid - NUM_SMS
        num_pid_in_group = GROUP_SIZE_M * num_pid_n

        for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True, warp_specialize=WARP_SPECIALIZE):
            pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
            offs_am = pid_m * BLOCK_SIZE_M
            offs_bn = pid_n * BLOCK_SIZE_N

            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
            for ki in range(k_tiles):
                offs_k = ki * BLOCK_SIZE_K
                a = a_desc.load([offs_am, offs_k])
                b = b_desc.load([offs_bn, offs_k])
                accumulator = tl.dot(a, b.T, accumulator)

            tile_id_c += NUM_SMS
            pid_m, pid_n = _compute_pid(tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
            offs_cm = pid_m * BLOCK_SIZE_M
            offs_cn = pid_n * BLOCK_SIZE_N

            if EPILOGUE_SUBTILE:
                acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2))
                acc = tl.permute(acc, (0, 2, 1))
                acc0, acc1 = tl.split(acc)
                c0 = acc0.to(dtype)
                c_desc.store([offs_cm, offs_cn], c0)
                c1 = acc1.to(dtype)
                c_desc.store([offs_cm, offs_cn + BLOCK_SIZE_N // 2], c1)
            else:
                c = accumulator.to(dtype)
                c_desc.store([offs_cm, offs_cn], c)

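    # Host-side wrapper for the device-descriptor variant: unlike matmul_tma_persistent, no
    # TensorDescriptor objects are built on the host. tl.make_tensor_descriptor inside the kernel
    # needs scratch space in global memory, which is supplied through triton.set_allocator below.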
    def matmul_descriptor_persistent(a, b, warp_specialize: bool):
        # Check constraints.
        assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
        assert a.dtype == b.dtype, "Incompatible dtypes"
        M, K = a.shape
        N, K = b.shape
        dtype = a.dtype

        c = torch.empty((M, N), device=a.device, dtype=dtype)
        NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

        # TMA descriptors require a global memory allocation
        def alloc_fn(size: int, alignment: int, stream: Optional[int]):
            return torch.empty(size, device="cuda", dtype=torch.int8)

        triton.set_allocator(alloc_fn)

        grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"])), )
        matmul_kernel_descriptor_persistent[grid](
            a, b, c,  #
            M, N, K,  #
            NUM_SMS=NUM_SMS,  #
            WARP_SPECIALIZE=warp_specialize,  #
        )
        return c


    def cublas_matmul(a, b):
        # Check constraints.
        assert a.shape[1] == b.shape[1], "Incompatible dimensions"  # b is transposed
        M, K = a.shape
        N, K = b.shape
        dtype = a.dtype
        c = torch.empty((M, N), device=a.device, dtype=dtype)
        bytes_per_elem = a.element_size()
        flops_str = f"flops{bytes_per_elem * 8}"
        with proton.scope(f"cublas [M={M}, N={N}, K={K}]",
                          {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
            cublas.matmul(a, b, c)
        return c


    def torch_matmul(a, b):
        M, K = a.shape
        N, K = b.shape
        bytes_per_elem = a.element_size()
        flops_str = f"flops{bytes_per_elem * 8}"
        with proton.scope(f"torch [M={M}, N={N}, K={K}]",
                          {"bytes": bytes_per_elem * (M * K + N * K + M * N), flops_str: 2. * M * N * K}):
            c = torch.matmul(a, b.T)
        return c


    @contextmanager
    def proton_context():
        proton.activate(0)
        try:
            yield
        finally:
            proton.deactivate(0)


    def bench_fn(label, reps, warmup_reps, fn, *args):
        print(f"Benchmarking {label}: ...", end="")
        for _ in range(warmup_reps):
            fn(*args)
        with proton_context():
            for _ in range(reps):
                fn(*args)
        print(f"\rBenchmarking {label}: done")


    def bench(K, dtype, reps=10000, warmup_reps=10000):
        M = 8192
        N = 8192
        a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
        b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
        b = b.T.contiguous()

        if cublas is not None:
            bench_fn("cublas", reps, warmup_reps, cublas_matmul, a, b)
        if dtype == torch.float16:
            bench_fn("torch", reps, warmup_reps, torch_matmul, a, b)
        bench_fn("naive", reps, warmup_reps, matmul, a, b.T)
        bench_fn("persistent", reps, warmup_reps, matmul_persistent, a, b.T)
        warp_specialize = [False, True] if HAS_WARP_SPECIALIZE else [False]
        for ws in warp_specialize:
            ws_str = "_ws" if ws else ""
            if HAS_HOST_TENSOR_DESC:
                bench_fn(f"tma_persistent{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma_persistent(a, b, ws), a, b)
                bench_fn(f"tma{ws_str}", reps, warmup_reps, lambda a, b: matmul_tma(a, b, ws), a, b)
            if HAS_TENSOR_DESC:
                bench_fn(f"descriptor_persistent{ws_str}", reps, warmup_reps,
                         lambda a, b: matmul_descriptor_persistent(a, b, ws), a, b)


    def run_test(expect, fn, a, b, label, enabled=True):
        print(f" {label}: ...", end="")
        if enabled:
            actual = fn(a, b)
            passed = torch.allclose(expect, actual.to(expect.dtype), atol=1.0)
            icon = "✅" if passed else "❌"
        else:
            icon = "⭕"
        print(f"\r {label}: {icon} ")


    def validate(M, N, K, dtype):
        print(f"{M=}, {N=}, {K=}, verification naive vs: ")
        a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
        b = torch.randn((K, N), device="cuda", dtype=torch.float16).to(dtype)
        b = b.T.contiguous()

        naive_result = matmul(a, b.T).to(torch.float16)
        run_test(naive_result, torch_matmul, a, b, "Torch", enabled=dtype == torch.float16)
        run_test(naive_result, cublas_matmul, a, b, "cuBLAS", enabled=cublas is not None)
        run_test(naive_result, matmul_persistent, a, b.T, "Persistent")

        kernels = [
            (matmul_tma, "TMA", HAS_HOST_TENSOR_DESC),
            (matmul_tma_persistent, "TMA Persistent", HAS_HOST_TENSOR_DESC),
            (matmul_descriptor_persistent, "Tensor Descriptor Persistent", HAS_TENSOR_DESC),
        ]
        warp_specialize = [False, True] if HAS_WARP_SPECIALIZE else [False]
        for (kernel, label, enabled), warp_specialize in itertools.product(kernels, warp_specialize):
            label = f"{label} (warp_specialize={warp_specialize})"
            enabled = enabled and (not warp_specialize or HAS_TENSOR_DESC)
            run_test(naive_result, lambda a, b: kernel(a, b, warp_specialize), a, b, label, enabled)
        print()


    def show_profile(precision, profile_name):
        import triton.profiler.viewer as proton_viewer
        metric_names = ["time/ms"]
        if precision == 'fp8':
            metric_names = ["tflop8/s"] + metric_names
        elif precision == 'fp16':
            metric_names = ["tflop16/s"] + metric_names
        file_name = f"{profile_name}.hatchet"
        tree, metrics = proton_viewer.parse(metric_names, file_name)
        proton_viewer.print_tree(tree, metrics)


    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument("-K", type=int, required=False, default=512)
        parser.add_argument("--K_range", type=int, nargs=2)
        parser.add_argument("--K_step", type=int, default=512)
        parser.add_argument("--prec", type=str, choices=["fp8", "fp16"], default="fp16")
        args = parser.parse_args()

        if args.prec == 'fp8' and (not hasattr(torch, "float8_e4m3fn") or not is_cuda()):
            print("This example requires CUDA with fp8 support.")
        else:
            dtype = torch.float8_e4m3fn if args.prec == 'fp8' else torch.float16

            if args.K and args.K_range is None:
                args.K_range = [args.K, args.K]
                args.K_step = 1  # doesn't matter as long as it's not 0

            torch.manual_seed(0)

            validate(32, 32, 32, dtype)
            validate(8192, 8192, args.K_range[0], dtype)

            proton.start("matmul", hook="triton")
            proton.deactivate()
            for K in range(args.K_range[0], args.K_range[1] + 1, args.K_step):
                bench(K, dtype)
            proton.finalize()
            show_profile(args.prec, "matmul")


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 11.051 seconds)


.. _sphx_glr_download_getting-started_tutorials_09-persistent-matmul.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 09-persistent-matmul.ipynb <09-persistent-matmul.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 09-persistent-matmul.py <09-persistent-matmul.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 09-persistent-matmul.zip <09-persistent-matmul.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_