(example_matmul-multicta)=
# Matmul Multicta

This example can be found at ``python/examples/gluon/03-matmul-multicta.py``.

````python
import argparse
import pytest
import torch

import triton
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    allocate_tensor_memory,
    clc,
    tcgen05_commit,
    tcgen05_mma,
    tcgen05_mma_barrier_count,
    tensor_memory_descriptor,
)
from triton.experimental.gluon.language.nvidia.hopper import mbarrier, tma
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor


def is_blackwell():
    if not torch.cuda.is_available():
        return False
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda" and torch.cuda.get_device_capability()[0] == 10


def as_gl_dtype(torch_dtype):
    if torch_dtype == torch.float16:
        return gl.float16
    if torch_dtype == torch.bfloat16:
        return gl.bfloat16
    if torch_dtype == torch.float32:
        return gl.float32
    raise ValueError(f"Unsupported dtype for Gluon layout: {torch_dtype}")


@gluon.constexpr_function
def get_split_dim(cga_layout, dim):
    return 1 << sum(b[dim] != 0 for b in cga_layout)


def get_epilogue_size_n(block_m, block_n, cga_layout):
    """
    We can't split a layout along one of the first 4 warps or along a CTA, as
    each of these has its own address space. It would be possible to split it
    if we used a smaller BLOCK_N for both the TMA and MMA instructions, but
    this is NYI.
    """
    # We can't split the layout along N: on M=64 2CTA layouts
    # the basis (0, TileN) is owned by the second warp basis!
    if block_m == 64 and cga_layout:
        return block_n
    # We can't split the layout along N as the last basis along N
    # is owned by a different CTA!
    if get_split_dim(cga_layout, 1) > 1:
        return block_n
    return 32
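

# Illustrative only (not part of the original example): concrete values of the
# helpers above for the CGA layouts used by the autotuner configs below.
#   get_split_dim((), 0)                == 1    # no CTA split
#   get_split_dim(((1, 0), ), 0)        == 2    # 2 CTAs split along M
#   get_split_dim(((1, 0), (2, 0)), 0)  == 4    # 4 CTAs split along M
#   get_split_dim(((1, 0), ), 1)        == 1    # no CTA split along N
#   get_epilogue_size_n(128, 256, ())        == 32   # epilogue may subtile along N
#   get_epilogue_size_n(64, 128, ((1, 0), )) == 128  # M=64 2CTA layout: no N split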


def matmul_get_configs(pre_hook=None):
    return [
        triton.Config(
            {
                "BLOCK_SIZE_M": BM,
                "BLOCK_SIZE_N": BN,
                "BLOCK_SIZE_K": BK,
                "GRID_MINOR_DIM": minor_dim,
                "GRID_TILE_WIDTH": grid_tile_width,
                "STAGES": stages,
                "ACC_STAGES": acc_stages,
                "EPILOGUE_SIZE_N": get_epilogue_size_n(BM, BN, cga_layout),
                "SUBTILE_STAGES": subtile_stages,
                "CGA_LAYOUT": cga_layout,
            },
            num_warps=4,
            num_ctas=2**len(cga_layout),
            pre_hook=pre_hook,
        )
        for BM in (64, 128)
        for BN in (128, 256, 512)
        for BK in (64, 128)
        for minor_dim in (0, 1)
        for grid_tile_width in (4, 8, 16)
        for stages in (2, 4, 6)
        for acc_stages in (2, )
        for subtile_stages in (4, )
        for cga_layout in ((), ((1, 0), ), ((1, 0), (2, 0)))
        if BN // get_split_dim(cga_layout, 1) <= 256
        # Trim some configs with too large a tile
        if not (BN == 512 and len(cga_layout) == 0)
    ]


def matmul_tma_set_block_size_hook(nargs):
    block_m = nargs["BLOCK_SIZE_M"]
    block_n = nargs["BLOCK_SIZE_N"]
    block_k = nargs["BLOCK_SIZE_K"]
    epilogue_size_n = nargs["EPILOGUE_SIZE_N"]
    cga_layout = nargs["CGA_LAYOUT"]
    tile_m = block_m * get_split_dim(cga_layout, 0)
    nargs["a_desc"].block_shape = [tile_m, block_k]
    nargs["b_desc"].block_shape = [block_k, block_n]
    nargs["c_desc"].block_shape = [tile_m, epilogue_size_n]

    def get_cga_layout(layout, op_idx):
        assert op_idx in (0, 1)
        if not layout:
            return layout
        # 2CTA performs an outer product so bases are [1, 0] and [0, 1]
        assert layout[0] == (1, 0)
        first = (1, 0) if op_idx == 0 else (0, 1)

        # Broadcast along K (the reduction dimension).
        # We multiply by 2 for op_idx == 1, as we have added the (0, 1) basis.
        def broadcast(b):
            return (b[0], 0) if op_idx == 0 else (0, 2 * b[1])

        return (first, *map(broadcast, layout[1:]))

    cga_layout_a = get_cga_layout(cga_layout, 0)
    cga_layout_b = get_cga_layout(cga_layout, 1)
    cga_layout_c = cga_layout
    for desc, cga_layout in zip(("a_desc", "b_desc", "c_desc"), (cga_layout_a, cga_layout_b, cga_layout_c)):
        nargs[desc].layout = gl.NVMMASharedLayout.get_default_for(
            nargs[desc].block_shape,
            as_gl_dtype(nargs[desc].base.dtype),
            cga_layout=cga_layout,
        )


# From Pallas / CUTLASS
@gluon.jit
def _planar_snake(lin_idx, m_tiles, n_tiles, minor_dim: gl.constexpr, tile_width: gl.constexpr):
    major_size = n_tiles if minor_dim == 0 else m_tiles
    minor_size = m_tiles if minor_dim == 0 else n_tiles

    full_minor_tiles = minor_size // tile_width
    full_minor_size = full_minor_tiles * tile_width
    full_elements = full_minor_tiles * tile_width * major_size

    minor_tile_idx = lin_idx // (tile_width * major_size)

    full_minor_within = lin_idx % tile_width
    full_major_within = (lin_idx // tile_width) % major_size
    full_minor = minor_tile_idx * tile_width + full_minor_within
    full_major = gl.where((minor_tile_idx % 2) == 0, full_major_within, major_size - 1 - full_major_within)

    partial_width = minor_size - full_minor_size
    partial_width = gl.where(partial_width > 0, partial_width, 1)
    partial_lin = lin_idx - full_elements
    partial_minor_within = partial_lin % partial_width
    partial_major_within = (partial_lin // partial_width) % major_size
    partial_minor = minor_tile_idx * tile_width + partial_minor_within
    partial_major = gl.where((minor_tile_idx % 2) == 0, partial_major_within, major_size - 1 - partial_major_within)

    in_full_tile = lin_idx < full_elements
    minor = gl.where(in_full_tile, full_minor, partial_minor)
    major = gl.where(in_full_tile, full_major, partial_major)

    if minor_dim == 0:
        return minor, major
    return major, minor


@gluon.aggregate
class Counter:
    index: gl.tensor
    phase: gl.tensor
    num_barriers: gl.constexpr

    @gluon.jit
    def create(phase, num_barriers: gl.constexpr):
        return Counter(gl.to_tensor(0), gl.to_tensor(phase), num_barriers)

    @gluon.must_use_result
    @gluon.jit
    def next(self, pred=True):
        incr = self.index + gl.where(pred, 1, 0)
        rollover = incr == self.num_barriers
        index = gl.where(rollover, 0, incr)
        phase = gl.where(rollover, self.phase ^ 1, self.phase)
        return Counter(index, phase, self.num_barriers)
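

# Illustrative only (not part of the original example): how Counter cycles through a
# ring of barrier slots. With num_barriers=2, repeated .next() calls produce the
# (index, phase) sequence
#   (0, 0) -> (1, 0) -> (0, 1) -> (1, 1) -> (0, 0) -> ...
# i.e. the phase bit flips every time the index wraps, which is what
# mbarrier.wait(barrier, phase) uses to tell successive uses of the same slot apart.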


@gluon.aggregate
class ClcTileSchedulerConsumer:
    has_work: gl.tensor
    tile_id: gl.tensor
    pid_m: gl.tensor
    pid_n: gl.tensor
    num_pid_m: gl.tensor
    num_pid_n: gl.tensor
    TILE_M: gl.constexpr
    TILE_N: gl.constexpr
    MINOR_DIM: gl.constexpr
    GRID_TILE_WIDTH: gl.constexpr
    clc_result_buffers: gl.shared_memory_descriptor
    clc_barriers: gl.shared_memory_descriptor
    clc_planar_pid_buffers: gl.shared_memory_descriptor
    clc_planar_ready_bars: gl.shared_memory_descriptor
    clc_consumed_bars: gl.shared_memory_descriptor
    counter: Counter
    consumed_counter: Counter

    @gluon.jit
    def initialize(M, N, TILE_M: gl.constexpr, TILE_N: gl.constexpr, MINOR_DIM: gl.constexpr,
                   GRID_TILE_WIDTH: gl.constexpr, clc_result_buffers, clc_barriers, clc_planar_pid_buffers,
                   clc_planar_ready_bars, clc_consumed_bars):
        tile_id = gl.program_id(axis=0)
        num_pid_m = gl.cdiv(M, TILE_M)
        num_pid_n = gl.cdiv(N, TILE_N)
        pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, MINOR_DIM, GRID_TILE_WIDTH)
        has_work = gl.to_tensor(True)
        counter = Counter.create(0, clc_barriers.shape[0])
        consumed_counter = Counter.create(0, clc_barriers.shape[0])
        return ClcTileSchedulerConsumer(
            has_work,
            tile_id,
            pid_m,
            pid_n,
            num_pid_m,
            num_pid_n,
            TILE_M,
            TILE_N,
            MINOR_DIM,
            GRID_TILE_WIDTH,
            clc_result_buffers,
            clc_barriers,
            clc_planar_pid_buffers,
            clc_planar_ready_bars,
            clc_consumed_bars,
            counter,
            consumed_counter,
        )

    @gluon.jit
    def get_offsets(self):
        return self.pid_m * self.TILE_M, self.pid_n * self.TILE_N

    @gluon.jit
    def step(self, iteration):
        # The 0-th iteration uses the program_id as the tile_id.
        # At the end of each iteration we prefetch the next tile.
        # As such we must signal the consumed slot at the end of
        # each iteration skipping the first one.
        consumed_counter = self.consumed_counter
        if iteration > 0:
            mbarrier.arrive(self.clc_consumed_bars.index(consumed_counter.index))
            consumed_counter = consumed_counter.next()

        counter = self.counter
        barrier = self.clc_barriers.index(counter.index)
        result = self.clc_result_buffers.index(counter.index)
        mbarrier.wait(barrier, counter.phase)
        clc_res = clc.load_result(result)

        mbarrier.wait(self.clc_planar_ready_bars.index(counter.index), counter.phase)
        planar_slot = self.clc_planar_pid_buffers.index(counter.index)
        planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
                                                       [[0]] * (gl.num_ctas().bit_length() - 1))
        packed_pid = planar_slot.load(planar_layout).reshape([])
        pid_m = ((packed_pid >> 32) & 0xFFFFFFFF).to(gl.int32)
        pid_n = (packed_pid & 0xFFFFFFFF).to(gl.int32)

        has_work = clc_res.is_canceled()
        tile_id = self.tile_id
        if has_work:
            tile_id = clc_res.program_id(0)

        return ClcTileSchedulerConsumer(
            has_work,
            tile_id,
            pid_m,
            pid_n,
            self.num_pid_m,
            self.num_pid_n,
            self.TILE_M,
            self.TILE_N,
            self.MINOR_DIM,
            self.GRID_TILE_WIDTH,
            self.clc_result_buffers,
            self.clc_barriers,
            self.clc_planar_pid_buffers,
            self.clc_planar_ready_bars,
            self.clc_consumed_bars,
            counter.next(),
            consumed_counter,
        )


@gluon.aggregate
class PartitionArgs:
    a_desc: tma.tensor_descriptor
    b_desc: tma.tensor_descriptor
    c_desc: tma.tensor_descriptor
    a_bufs: gl.shared_memory_descriptor
    b_bufs: gl.shared_memory_descriptor
    load_empty_bars: gl.shared_memory_descriptor
    load_ready_bars: gl.shared_memory_descriptor
    acc_bufs: tensor_memory_descriptor
    acc_empty_bars: gl.shared_memory_descriptor
    acc_ready_bars: gl.shared_memory_descriptor
    clc_result_buffers: gl.shared_memory_descriptor
    clc_barriers: gl.shared_memory_descriptor
    clc_planar_pid_buffers: gl.shared_memory_descriptor
    clc_planar_ready_bars: gl.shared_memory_descriptor
    clc_consumed_bars: gl.shared_memory_descriptor
    MINOR_DIM: gl.constexpr
    GRID_TILE_WIDTH: gl.constexpr
    SUBTILE_STAGES: gl.constexpr

    @gluon.jit
    def get_clc_consumer(self):
        return ClcTileSchedulerConsumer.initialize(
            self.c_desc.shape[0],
            self.c_desc.shape[1],
            self.a_desc.block_shape[0],
            self.b_desc.block_shape[1],
            self.MINOR_DIM,
            self.GRID_TILE_WIDTH,
            self.clc_result_buffers,
            self.clc_barriers,
            self.clc_planar_pid_buffers,
            self.clc_planar_ready_bars,
            self.clc_consumed_bars,
        )
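

# Illustrative only (not part of the original example): the CLC partition below packs the
# planar-snake (pid_m, pid_n) pair into a single int64 so one shared-memory slot can
# broadcast it to all consumer partitions. For example, for pid_m=3 and pid_n=5:
#   packed = (3 << 32) | 5                  # 0x0000000300000005
#   pid_m  = (packed >> 32) & 0xFFFFFFFF    # 3
#   pid_n  = packed & 0xFFFFFFFF            # 5
# ClcTileSchedulerConsumer.step() above performs the matching unpacking.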


@gluon.jit
def matmul_clc_partition(p):
    TILE_M: gl.constexpr = p.a_desc.block_shape[0]
    TILE_N: gl.constexpr = p.b_desc.block_shape[1]
    has_work = gl.to_tensor(True)
    num_pid_m = gl.cdiv(p.c_desc.shape[0], TILE_M)
    num_pid_n = gl.cdiv(p.c_desc.shape[1], TILE_N)
    state = Counter.create(0, p.clc_barriers.shape[0])
    consumed_state = Counter.create(1, p.clc_barriers.shape[0])
    ACC_STAGES: gl.constexpr = p.clc_barriers.shape[0]

    i = 0
    while has_work:
        # Reuse the slot only after all consumer partitions signaled consumed.
        mbarrier.wait(p.clc_consumed_bars.index(consumed_state.index), consumed_state.phase,
                      pred=(i >= ACC_STAGES))
        barrier = p.clc_barriers.index(state.index)
        result = p.clc_result_buffers.index(state.index)
        # 16: clc.try_cancel has a `.b128` modifier
        mbarrier.expect(barrier, 16)
        clc.try_cancel(result, barrier)

        mbarrier.wait(barrier, state.phase)
        clc_res = clc.load_result(result)
        has_work = clc_res.is_canceled()
        pid_m = gl.to_tensor(0)
        pid_n = gl.to_tensor(0)
        if has_work:
            tile_id = clc_res.program_id(0)
            pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, p.MINOR_DIM, p.GRID_TILE_WIDTH)
        packed_pid = (pid_m.to(gl.int64) << 32) | (pid_n.to(gl.int64) & 0xFFFFFFFF)
        planar_slot = p.clc_planar_pid_buffers.index(state.index)
        planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
                                                       [[0]] * (gl.num_ctas().bit_length() - 1))
        planar_slot.store(gl.full([1], packed_pid, gl.int64, layout=planar_layout))
        mbarrier.arrive(p.clc_planar_ready_bars.index(state.index))

        state = state.next()
        consumed_state = consumed_state.next()
        i += 1


@gluon.jit
def matmul_load_partition(p):
    BLOCK_K: gl.constexpr = p.a_desc.block_shape[1]
    K = p.a_desc.shape[1]
    concurrent_loads: gl.constexpr = p.load_ready_bars.shape[0]
    state = Counter.create(1, concurrent_loads)
    scheduler = p.get_clc_consumer()

    i = 0
    while scheduler.has_work:
        off_m, off_n = scheduler.get_offsets()
        for k in range(0, K, BLOCK_K):
            pred = (i > 0) or (k >= BLOCK_K * concurrent_loads)
            mbarrier.wait(p.load_empty_bars.index(state.index), state.phase, pred=pred)

            bar = p.load_ready_bars.index(state.index)
            mbarrier.expect(bar, p.a_desc.nbytes_per_cta + p.b_desc.nbytes_per_cta)
            tma.async_load(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True)
            tma.async_load(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True)
            state = state.next()
        scheduler = scheduler.step(i)
        i += 1


@gluon.jit
def matmul_mma_partition(p):
    BLOCK_K: gl.constexpr = p.a_desc.block_shape[1]
    K = p.a_desc.shape[1]
    ACC_STAGES: gl.constexpr = p.acc_empty_bars.shape[0]
    load_state = Counter.create(0, p.load_empty_bars.shape[0])
    acc_state = Counter.create(1, ACC_STAGES)
    scheduler = p.get_clc_consumer()

    i = 0
    while scheduler.has_work:
        acc_buf = p.acc_bufs.index(acc_state.index)
        mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase, pred=(i >= ACC_STAGES))

        use_acc = False
        for k in range(0, K, BLOCK_K):
            mbarrier.wait(p.load_ready_bars.index(load_state.index), load_state.phase)
            tcgen05_mma(p.a_bufs.index(load_state.index), p.b_bufs.index(load_state.index), acc_buf,
                        use_acc=use_acc, multicast=True, mbarriers=[p.load_empty_bars.index(load_state.index)])
            load_state = load_state.next()
            use_acc = True

        tcgen05_commit(p.acc_ready_bars.index(acc_state.index), descs=[p.a_bufs.index(0), p.b_bufs.index(0)])
        acc_state = acc_state.next()
        scheduler = scheduler.step(i)
        i += 1
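

# Illustrative only (not part of the original example): summary of the mbarrier
# handshakes that pipeline the partitions above (names match the allocations in
# _matmul_kernel below).
#   load_ready_bars : load partition -> MMA partition   (A/B tile landed in shared memory)
#   load_empty_bars : MMA partition  -> load partition  (shared-memory stage may be refilled)
#   acc_ready_bars  : MMA partition  -> epilogue        (accumulator tile is complete)
#   acc_empty_bars  : epilogue       -> MMA partition   (tensor-memory stage may be reused)
#   clc_*           : CLC partition <-> consumers       (next tile id published / consumed)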


@gluon.jit
def matmul_epilogue_partition(p):
    TILE_M: gl.constexpr = p.a_desc.block_shape[0]
    TILE_N: gl.constexpr = p.b_desc.block_shape[1]
    SPLIT_TILE_N: gl.constexpr = p.c_desc.block_shape[1]
    # Separate knobs: SUBTILE_STAGES controls shared-memory usage,
    # and SUBTILE_FACTOR is the maximum number of subtiles into which we can split the tile,
    # which might be too large to fit within shared-memory limits.
    SUBTILE_FACTOR: gl.constexpr = TILE_N // SPLIT_TILE_N
    SUBTILE_STAGES: gl.constexpr = p.SUBTILE_STAGES
    ACC_STAGES: gl.constexpr = p.acc_empty_bars.shape[0]
    dtype: gl.constexpr = p.c_desc.dtype
    acc_state = Counter.create(0, ACC_STAGES)
    acc_smems = gl.allocate_shared_memory(dtype, [SUBTILE_STAGES, TILE_M, SPLIT_TILE_N], p.c_desc.layout)
    sub_acc_state = Counter.create(0, SUBTILE_STAGES)
    scheduler = p.get_clc_consumer()

    i = 0
    while scheduler.has_work:
        off_m, off_n = scheduler.get_offsets()
        mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
        acc_buf = p.acc_bufs.index(acc_state.index)
        for s in gl.static_range(SUBTILE_FACTOR):
            acc_sub = acc_buf.slice(SPLIT_TILE_N * s, SPLIT_TILE_N)
            acc_smem = acc_smems.index(sub_acc_state.index)
            acc = acc_sub.load().to(dtype)
            tma.store_wait(pendings=SUBTILE_STAGES - 1)
            acc_smem.store(acc)
            tma.async_copy_shared_to_global(p.c_desc, [off_m, off_n + SPLIT_TILE_N * s], acc_smem)
            sub_acc_state = sub_acc_state.next()
        # Signal that the accumulator slot can be reused only after all stores are done.
        mbarrier.arrive(p.acc_empty_bars.index(acc_state.index))
        acc_state = acc_state.next()
        scheduler = scheduler.step(i)
        i += 1
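

# Illustrative only (not part of the original example): with the config used by the
# benchmark below (TILE_N=256, EPILOGUE_SIZE_N=32, SUBTILE_STAGES=4), the epilogue above
# splits each accumulator tile into SUBTILE_FACTOR = 256 // 32 = 8 column subtiles and
# cycles them through 4 shared-memory staging buffers; tma.store_wait(pendings=SUBTILE_STAGES - 1)
# ensures a staging buffer's previous TMA store has completed before it is overwritten.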


@gluon.jit
def _matmul_kernel(
    a_desc, b_desc, c_desc,
    M, N, K,
    BLOCK_SIZE_M: gl.constexpr,
    BLOCK_SIZE_N: gl.constexpr,
    BLOCK_SIZE_K: gl.constexpr,
    GRID_MINOR_DIM: gl.constexpr,
    GRID_TILE_WIDTH: gl.constexpr,
    STAGES: gl.constexpr,
    ACC_STAGES: gl.constexpr,
    CGA_LAYOUT: gl.constexpr,
    EPILOGUE_SIZE_N: gl.constexpr,
    SUBTILE_STAGES: gl.constexpr,
):
    BLOCK_M: gl.constexpr = a_desc.block_shape[0]
    BLOCK_N: gl.constexpr = b_desc.block_shape[1]
    TWO_CTAS: gl.constexpr = gl.num_ctas() > 1
    N_PARTITIONS: gl.constexpr = 4

    dtype: gl.constexpr = a_desc.dtype
    a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + a_desc.block_shape, a_desc.layout)
    b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout)
    # Number of CTAs that will arrive on the barrier from a tcgen05_commit after an MMA instruction
    mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True)

    # Equiv. consumed_barrier. Barrier TCGEN05 MMA -> Load TMA
    load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES)
    # Equiv. ab_tma_barrier. Barrier Load TMA -> TCGEN05 MMA
    load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=TWO_CTAS)
    for i in gl.static_range(STAGES):
        mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count)
        mbarrier.init(load_ready_bars.index(i), count=1)

    tmem_layout: gl.constexpr = TensorMemoryLayout(
        [BLOCK_SIZE_M, BLOCK_N // get_split_dim(CGA_LAYOUT, 1)],
        col_stride=1,
        cga_layout=CGA_LAYOUT,
        two_ctas=TWO_CTAS,
    )
    acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, BLOCK_M, BLOCK_N], tmem_layout)
    # Equiv. store_done_barrier. Barrier Store TMA -> TCGEN05 MMA
    acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=TWO_CTAS)
    # Equiv. mma_done_barrier. Barrier TCGEN05 MMA -> Store TMA
    acc_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
    for i in gl.static_range(ACC_STAGES):
        mbarrier.init(acc_empty_bars.index(i), count=1)
        mbarrier.init(acc_ready_bars.index(i), count=mma_barrier_count)

    clc_barriers = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
    clc_planar_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
    clc_consumed_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=TWO_CTAS)
    for i in gl.static_range(ACC_STAGES):
        mbarrier.init(clc_barriers.index(i), count=1)
        mbarrier.init(clc_planar_ready_bars.index(i), count=1)
        # Every partition but itself arrives on the barrier
        mbarrier.init(clc_consumed_bars.index(i), count=N_PARTITIONS - 1)

    cga_layout: gl.constexpr = [[0]] * (gl.num_ctas().bit_length() - 1)
    clc_layout: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, [0], cga_layout=cga_layout)
    clc_result_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 2], clc_layout)
    clc_planar_pid_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 1], clc_layout)

    p = PartitionArgs(
        a_desc, b_desc, c_desc,
        a_bufs, b_bufs,
        load_empty_bars, load_ready_bars,
        acc_bufs, acc_empty_bars, acc_ready_bars,
        clc_result_buffers, clc_barriers, clc_planar_pid_buffers, clc_planar_ready_bars, clc_consumed_bars,
        GRID_MINOR_DIM, GRID_TILE_WIDTH, SUBTILE_STAGES,
    )
    gl.warp_specialize([
        (matmul_epilogue_partition, (p, )),
        (matmul_load_partition, (p, )),
        (matmul_mma_partition, (p, )),
        (matmul_clc_partition, (p, )),
    ], [1, 1, 1], [24, 24, 24])


matmul_kernel = triton.autotune(
    configs=matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook),
    key=["M", "N", "K"],
)(_matmul_kernel)


def matmul_with_config(
    a, b, out=None, *,
    block_size_m, block_size_n, block_size_k,
    grid_minor_dim, grid_tile_width,
    stages, acc_stages,
    cga_layout, epilogue_size_n, subtile_stages,
):
    if block_size_n // get_split_dim(cga_layout, 1) > 256:
        raise ValueError(
            f"cga_layout={list(cga_layout)} only supports BLOCK_SIZE_N <= {256 * get_split_dim(cga_layout, 1)}")
    M, K = a.shape
    K1, N = b.shape
    if K != K1:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if a.dtype != torch.float16 or b.dtype != torch.float16:
        raise ValueError("matmul only supports fp16 inputs")

    if out is None:
        c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    else:
        if out.shape != (M, N):
            raise ValueError(f"Output has invalid shape {out.shape}, expected {(M, N)}")
        if out.device != a.device or out.dtype != a.dtype:
            raise ValueError("Output must match input device and dtype")
        c = out

    dummy_block = [1, 1]
    dummy_layout = gl.NVMMASharedLayout.get_default_for(dummy_block, gl.float16)
    a_desc = TensorDescriptor.from_tensor(a, dummy_block, dummy_layout)
    b_desc = TensorDescriptor.from_tensor(b, dummy_block, dummy_layout)
    c_desc = TensorDescriptor.from_tensor(c, dummy_block, dummy_layout)

    matmul_tma_set_block_size_hook({
        "a_desc": a_desc,
        "b_desc": b_desc,
        "c_desc": c_desc,
        "BLOCK_SIZE_M": block_size_m,
        "BLOCK_SIZE_N": block_size_n,
        "BLOCK_SIZE_K": block_size_k,
        "GRID_MINOR_DIM": grid_minor_dim,
        "GRID_TILE_WIDTH": grid_tile_width,
        "STAGES": stages,
        "ACC_STAGES": acc_stages,
        "CGA_LAYOUT": cga_layout,
        "EPILOGUE_SIZE_N": epilogue_size_n,
    })

    def grid(meta):
        tile_m = meta["BLOCK_SIZE_M"] * (2 if bool(meta["CGA_LAYOUT"]) else 1)
        tile_n = meta["BLOCK_SIZE_N"]
        num_tiles = triton.cdiv(M, tile_m) * triton.cdiv(N, tile_n)
        return (num_tiles, )

    _matmul_kernel[grid](
        a_desc, b_desc, c_desc,
        M, N, K,
        block_size_m, block_size_n, block_size_k,
        grid_minor_dim, grid_tile_width,
        stages, acc_stages,
        cga_layout, epilogue_size_n, subtile_stages,
        num_warps=4,
        num_ctas=2**len(cga_layout),
    )
    return c


def matmul(a, b):
    M, K = a.shape
    K1, N = b.shape
    if K != K1:
        raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
    if a.dtype != torch.float16 or b.dtype != torch.float16:
        raise ValueError("matmul only supports fp16 inputs")
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    dummy_block = [1, 1]
    dummy_layout = gl.NVMMASharedLayout.get_default_for(dummy_block, gl.float16)
    a_desc = TensorDescriptor.from_tensor(a, dummy_block, dummy_layout)
    b_desc = TensorDescriptor.from_tensor(b, dummy_block, dummy_layout)
    c_desc = TensorDescriptor.from_tensor(c, dummy_block, dummy_layout)

    def grid(meta):
        tile_m = meta["BLOCK_SIZE_M"] * (2 if bool(meta["CGA_LAYOUT"]) else 1)
        tile_n = meta["BLOCK_SIZE_N"]
        num_tiles = triton.cdiv(M, tile_m) * triton.cdiv(N, tile_n)
        return (num_tiles, )

    matmul_kernel[grid](a_desc, b_desc, c_desc, M, N, K)
    return c
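

# Illustrative only (not part of the original example): typical host-side usage, assuming
# a Blackwell GPU and fp16 inputs.
#
#   a = torch.randn((4096, 4096), device="cuda", dtype=torch.float16)
#   b = torch.randn((4096, 8192), device="cuda", dtype=torch.float16)
#   c = matmul(a, b)  # autotuned entry point; matmul_with_config() pins a single config
#   torch.testing.assert_close(c, a @ b, atol=1e-1, rtol=1e-2)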


# Subset of matmul_get_configs
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
@pytest.mark.parametrize("BLOCK_SIZE_M", [64, 128])
@pytest.mark.parametrize("BLOCK_SIZE_N", [128, 256])
@pytest.mark.parametrize("BLOCK_SIZE_K", [64, 128])
@pytest.mark.parametrize("GRID_MINOR_DIM", [0, 1])
@pytest.mark.parametrize("GRID_TILE_WIDTH", [8])
@pytest.mark.parametrize("CGA_LAYOUT", [(), ((1, 0), ), ((1, 0), (2, 0))])
@pytest.mark.parametrize("STAGES", [2, 4])
@pytest.mark.parametrize("ACC_STAGES", [2])
@pytest.mark.parametrize("EPILOGUE_SIZE_N", [32])
@pytest.mark.parametrize("SUBTILE_STAGES", [4])
@pytest.mark.parametrize("M, N, K", [(100, 200, 200)])
def test_matmul_matches_torch(
    M, N, K,
    BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K,
    GRID_MINOR_DIM, GRID_TILE_WIDTH,
    CGA_LAYOUT,
    STAGES, ACC_STAGES,
    EPILOGUE_SIZE_N, SUBTILE_STAGES,
):
    # To support epilogue splitting we need to be able to split within a CTA
    EPILOGUE_SIZE_N = get_epilogue_size_n(BLOCK_SIZE_M, BLOCK_SIZE_N, CGA_LAYOUT)
    torch.manual_seed(0)
    a = torch.rand((M, K), device=torch.device("cuda"), dtype=torch.float16)
    b = torch.rand((K, N), device=torch.device("cuda"), dtype=torch.float16)
    expected = torch.matmul(a, b)
    try:
        actual = matmul_with_config(
            a, b,
            block_size_m=BLOCK_SIZE_M,
            block_size_n=BLOCK_SIZE_N,
            block_size_k=BLOCK_SIZE_K,
            grid_minor_dim=GRID_MINOR_DIM,
            grid_tile_width=GRID_TILE_WIDTH,
            stages=STAGES,
            acc_stages=ACC_STAGES,
            cga_layout=CGA_LAYOUT,
            epilogue_size_n=EPILOGUE_SIZE_N,
            subtile_stages=SUBTILE_STAGES,
        )
    except triton.OutOfResources:
        pytest.skip("Out of resources")

    torch.testing.assert_close(expected, actual, atol=1e-1, rtol=1e-2)


########################################################
# Benchmarking
########################################################


def show_profile(profile_name):
    import triton.profiler.viewer as proton_viewer
    metric_names = ["tflop16/s", "time/ms"]
    file_name = f"{profile_name}.hatchet"
    tree, metrics = proton_viewer.parse(metric_names, file_name)
    proton_viewer.print_tree(tree, metrics)


def print_benchmark_header():
    print("=" * 60)
    print("Gluon Matmul Benchmark")
    print("=" * 60)
    props = torch.cuda.get_device_properties(0)
    print(f"Device: {props.name}, SMs: {props.multi_processor_count}")


def create_benchmark_tensors():
    M, N, K = 4096, 8192, 4096
    print(f"Matrix: M={M}, N={N}, K={K}")
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    c_triton = torch.empty((M, N), device="cuda", dtype=torch.float16)
    c_torch = torch.empty((M, N), device="cuda", dtype=torch.float16)
    expected = torch.matmul(a, b)
    return (M, N, K), a, b, c_triton, c_torch, expected


def get_benchmark_kernel_config():
    return {
        "tile_m": 128,
        "tile_n": 256,
        "tile_k": 64,
        "grid_minor_dim": 0,
        "grid_tile_width": 16,
        "stages": 6,
        "acc_stages": 2,
        "cga_layout": ((1, 0), ),
        "epilogue_tile_n": 32,
        "subtile_stages": 4,
    }


def make_gluon_runner(a, b, c_triton, cfg, use_autotuned=False):
    if use_autotuned:

        def run_gluon():
            return matmul(a, b)

        return run_gluon

    def run_gluon():
        return matmul_with_config(
            a, b, out=c_triton,
            block_size_m=cfg["tile_m"],
            block_size_n=cfg["tile_n"],
            block_size_k=cfg["tile_k"],
            grid_minor_dim=cfg["grid_minor_dim"],
            grid_tile_width=cfg["grid_tile_width"],
            stages=cfg["stages"],
            acc_stages=cfg["acc_stages"],
            cga_layout=cfg["cga_layout"],
            epilogue_size_n=cfg["epilogue_tile_n"],
            subtile_stages=cfg["subtile_stages"],
        )

    return run_gluon


def run_profile(shape, a, b, c_torch, run_gluon):
    import triton.profiler as proton
    M, N, K = shape

    proton.start("matmul", hook="triton")
    proton.deactivate(0)

    l2_cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()

    def bench_fn(label, reps, fn):
        print(f"Benchmarking {label}: ...", end="")
        proton.deactivate()
        for _ in range(5):
            fn()  # warmup
            triton.runtime.driver.active.clear_cache(l2_cache)
        try:
            for _ in range(reps):
                proton.deactivate()
                triton.runtime.driver.active.clear_cache(l2_cache)
                proton.activate()
                fn()
        finally:
            proton.deactivate()
        print(f"\rBenchmarking {label}: done")

    bytes_per_elem = a.element_size()
    scope_metrics = {
        "bytes": bytes_per_elem * (M * K + N * K + M * N),
        f"flops{bytes_per_elem * 8}": 2.0 * M * N * K,
    }

    def torch_profiled():
        with proton.scope(f"torch [M={M}, N={N}, K={K}]", scope_metrics):
            torch.matmul(a, b, out=c_torch)

    def gluon_profiled():
        with proton.scope(f"gluon [M={M}, N={N}, K={K}]", scope_metrics):
            run_gluon()

    bench_fn("torch", reps=100, fn=torch_profiled)
    bench_fn("gluon", reps=100, fn=gluon_profiled)

    proton.finalize()
    print("Proton profile written to `matmul.hatchet`")
    show_profile("matmul")


def benchmark(*, profile=True, use_autotuned=False):
    if not is_blackwell():
        raise RuntimeError("This benchmark requires a Blackwell CUDA GPU.")

    print_benchmark_header()
    shape, a, b, c_triton, c_torch, expected = create_benchmark_tensors()
    kernel_cfg = get_benchmark_kernel_config()

    runner_name = "matmul (autotuned)" if use_autotuned else "matmul_with_config"
    print(f"Gluon runner: {runner_name}")
    run_gluon = make_gluon_runner(a, b, c_triton, kernel_cfg, use_autotuned=use_autotuned)

    actual = run_gluon()
    torch.testing.assert_close(actual, expected, atol=1e-1, rtol=1e-2)
    if use_autotuned:
        print(f"Autotuned best config: {matmul_kernel.best_config}")

    if not profile:
        print("Skipping profiling (--no-profile).")
        return

    run_profile(shape, a, b, c_torch, run_gluon)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Gluon matmul benchmark")
    """
    To enable NCU profiling, run the script with the following example command:
    ```
    ncu --target-processes all \
        --set full \
        --import-source yes \
        --kernel-name-base function \
        --kernel-name 'regex:.*_matmul_kernel.*' \
        --launch-count 1 \
        -o ncu_triton_matmul \
        python 03-matmul-multicta.py --no-profile
    ```
    """
    parser.add_argument(
        "--no-profile",
        action="store_true",
        help="Skip Proton profiling and exit after validation.",
    )
    parser.add_argument(
        "--use-autotuned",
        action="store_true",
        help="Use autotuned matmul() instead of matmul_with_config() for the Gluon runner.",
    )
    args = parser.parse_args()
    benchmark(profile=not args.no_profile, use_autotuned=args.use_autotuned)
````