(example_moe-bmm1-fused-gather)=
# MOE Bmm1 Fused Gather

This example implements the first matmul (BMM1) of an MoE MLP as a single warp-specialized Gluon kernel for NVIDIA Blackwell: fp8 activation rows are gathered directly by TMA using per-expert routing indices, multiplied against mxfp4 (e2m1) weights with block scales on the tensor core, and the SwiGLU activation plus fp8 output quantization are fused into the epilogue.

This example can be found at ``python/examples/gluon/05-moe-bmm1-fused-gather.py``.

```python
from dataclasses import dataclass, replace
from itertools import chain

import pytest
import torch

import triton
import triton.experimental.gluon as gluon
import triton.experimental.gluon.language as gl
import triton.experimental.gluon.language.nvidia.blackwell as blackwell
import triton.experimental.gluon.language.nvidia.blackwell.tma as tma
from triton.experimental.gluon.language.nvidia.blackwell import float2
import triton.experimental.gluon.language.nvidia.hopper.mbarrier as mbarrier
import triton.language.extra.libdevice as libdevice
from triton.testing import do_bench_cudagraph
from triton_kernels.distributed import make_expt_dict_uniform
from triton_kernels.matmul import (
    FlexCtx,
    FnSpecs,
    FusedActivation,
    PrecisionConfig,
    matmul as reference_matmul,
)
from triton_kernels.numerics import InFlexData, OutFlexData
from triton_kernels.numerics_details.mxfp import MXFP_BLOCK_SIZE, downcast_to_mxfp
from triton_kernels.swiglu import swiglu_fn
from triton_kernels.tensor import (
    FP4,
    RaggedTensorMetadata,
    Tensor,
    convert_layout,
    make_ragged_tensor_metadata,
    wrap_torch_tensor,
)
from triton_kernels.tensor_details.dtype import UINT8
from triton_kernels.tensor_details.layout import (
    BlackwellMX4ValueShuffledLayout,
    make_default_matmul_mxfp4_w_scale_layout,
)
from triton_kernels.testing import alloc_rand, assert_close
from triton_kernels.topk import topk

# ===-----------------------------------------------------------------------===#
# Device Code
# ===-----------------------------------------------------------------------===#


@gluon.jit
def advance(idx: gl.tensor, phase: gl.tensor, num_bufs: gl.constexpr) -> tuple[gl.tensor, gl.tensor]:
    next_idx = idx + 1
    wrap = next_idx == num_bufs
    return gl.where(wrap, 0, next_idx), gl.where(wrap, phase ^ 1, phase)


@gluon.jit
def unpack_block_schedule(schedule: gl.tensor) -> tuple[gl.tensor, gl.tensor]:
    return schedule & 0xFFFF, schedule >> 16


@gluon.jit
def banded_row_major(block_id, grid_m, GRID_N: gl.constexpr, BAND_N: gl.constexpr):
    if BAND_N >= GRID_N:
        return block_id // GRID_N, block_id % GRID_N
    full_band_tiles = grid_m * BAND_N
    n_full_bands = GRID_N // BAND_N
    full_band_work = n_full_bands * full_band_tiles
    if block_id < full_band_work:
        band_id = block_id // full_band_tiles
        within_band = block_id % full_band_tiles
        return within_band // BAND_N, band_id * BAND_N + (within_band % BAND_N)
    tail_n = GRID_N - n_full_bands * BAND_N
    tail_idx = block_id - full_band_work
    return tail_idx // tail_n, n_full_bands * BAND_N + (tail_idx % tail_n)


@gluon.jit
def apply_block_schedule(
    block_id: gl.tensor,
    grid_m: gl.tensor,
    GRID_N: gl.constexpr,
    slice_offsets: gl.tensor,
    block_schedule: gl.tensor,
    BAND_N: gl.constexpr,
) -> tuple[gl.tensor, gl.tensor, gl.tensor, gl.tensor]:
    schedule_pid_m, pid_n = banded_row_major(block_id, grid_m, GRID_N, BAND_N=BAND_N)
    slice_idx, pid_m = unpack_block_schedule(gl.load(block_schedule + schedule_pid_m))
    slice_offset = gl.load(slice_offsets + slice_idx)
    return pid_m, pid_n, slice_idx, slice_offset


@gluon.jit
def unswizzle_mx_scale(
    smem,
    SIZE_OUTER: gl.constexpr,
    SIZE_INNER: gl.constexpr,
    MXFP_BLOCK_SIZE: gl.constexpr,
):
    rows: gl.constexpr = smem.shape[1]
    cols: gl.constexpr = smem.shape[2] * smem.shape[3] * smem.shape[4]
    tiles: gl.constexpr = cols // (SIZE_OUTER * SIZE_INNER)
    smem = smem.reshape((rows, tiles, MXFP_BLOCK_SIZE, SIZE_OUTER // MXFP_BLOCK_SIZE, SIZE_INNER))
    smem = smem.permute((0, 3, 2, 1, 4))
    return smem.reshape((rows * SIZE_OUTER, cols // SIZE_OUTER))
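
# The `advance` helper above implements the classic mbarrier ring-buffer
# protocol: an (index, phase) pair walks a ring of `num_bufs` staging buffers,
# and the phase bit flips on every wrap so waiters can distinguish successive
# passes over the same barrier. A plain-Python model of the same state machine
# (illustrative only, not used by the kernel):
def _advance_host_model(idx: int, phase: int, num_bufs: int) -> tuple[int, int]:
    next_idx = idx + 1
    if next_idx == num_bufs:
        return 0, phase ^ 1  # wrapped around: reuse buffer 0, flip the parity bit
    return next_idx, phase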

@gluon.jit
def alloc_barrier_ring(num_bufs: gl.constexpr, two_ctas: gl.constexpr = False):
    bars = mbarrier.allocate_mbarrier(batch=num_bufs, two_ctas=two_ctas)
    for i in gl.static_range(num_bufs):
        mbarrier.init(bars.index(i), count=1)
    return bars


@gluon.jit
def alloc_ring_barriers(
    num_bufs: gl.constexpr,
    producer_two_ctas: gl.constexpr = False,
    consumer_two_ctas: gl.constexpr = False,
):
    return (
        alloc_barrier_ring(num_bufs, two_ctas=producer_two_ctas),
        alloc_barrier_ring(num_bufs, two_ctas=consumer_two_ctas),
    )


@gluon.jit
def pack_e4m3x2(values):
    return gl.inline_asm_elementwise(
        """
        {
            .reg .f32 lane<2>;
            mov.b64 {lane0, lane1}, $1;
            cvt.rn.satfinite.e4m3x2.f32 $0, lane1, lane0;
        }
        """,
        "=h,l",
        [values.value],
        dtype=gl.int16,
        is_pure=True,
        pack=1,
    )


@gluon.jit
def pack_u16x2(x0, x1):
    return gl.inline_asm_elementwise(
        """
        mov.b32 $0, { $1, $2 };
        """,
        "=r,h,h",
        [x0, x1],
        dtype=gl.int32,
        is_pure=True,
        pack=1,
    )


@gluon.jit
def pack_fp8x4(values):
    lhs, rhs = gl.split(values.reshape((values.shape[0], values.shape[1] // 2, 2)))
    return pack_u16x2(lhs, rhs)


@gluon.jit
def _split_m(values):
    return gl.split(values.reshape((2, values.shape[0] // 2, values.shape[1])).permute((1, 2, 0)))


@gluon.jit
def _split_m_float2(values):
    lhs, rhs = _split_m(values.value)
    return float2.Float2Tensor(lhs), float2.Float2Tensor(rhs)


@gluon.jit
def split_m_subtiles(values, subtile_factor: gl.constexpr):
    # For epilogue subtiling.
    subtiles = (values, )
    for split_level in gl.static_range(5):
        if (1 << split_level) < subtile_factor:
            next_subtiles = ()
            for subtile_idx in gl.static_range(1 << split_level):
                lhs, rhs = _split_m_float2(subtiles[subtile_idx])
                next_subtiles += (lhs, rhs)
            subtiles = next_subtiles
    return subtiles


@gluon.aggregate
class PartitionArgs:
    x_desc: tma.tensor_descriptor
    w_desc: tma.tensor_descriptor
    scale_desc: tma.tensor_descriptor
    out_desc: tma.tensor_descriptor
    x_scale_ptr: gl.tensor | gl.constexpr
    w_scale_ptr: gl.tensor | gl.constexpr
    out_scale_ptr: gl.tensor
    out_ptr: gl.tensor
    bias_ptr: gl.tensor
    bias_stride: gl.tensor
    gather_indx_ptr: gl.tensor
    x_slice_sizes: gl.tensor
    x_slice_offs: gl.tensor
    x_block_schedule: gl.tensor
    x_bufs: gl.shared_memory_descriptor
    x_empty_bars: gl.shared_memory_descriptor
    x_ready_bars: gl.shared_memory_descriptor
    x_num_bufs: gl.constexpr
    w_bufs: gl.shared_memory_descriptor
    w_scale_bufs: gl.shared_memory_descriptor
    w_empty_bars: gl.shared_memory_descriptor
    w_ready_bars: gl.shared_memory_descriptor
    w_num_bufs: gl.constexpr
    x_scale_tmem: blackwell.tensor_memory_descriptor
    w_scale_tmem: blackwell.tensor_memory_descriptor
    acc_bufs: blackwell.tensor_memory_descriptor
    acc_empty_bars: gl.shared_memory_descriptor
    acc_ready_bars: gl.shared_memory_descriptor
    acc_num_bufs: gl.constexpr
    grid_m: gl.tensor
    GRID_N: gl.constexpr
    K_TILES: gl.constexpr
    SCALE_FLAT_N: gl.constexpr
    SCALE_BLOCK_N_DIV: gl.constexpr
    num_blocks: gl.tensor
    NUM_SMS: gl.constexpr
    USE_2CTA: gl.constexpr
    BLOCK_M_PER_CTA: gl.constexpr
    BLOCK_M: gl.constexpr
    BLOCK_N: gl.constexpr
    BLOCK_K: gl.constexpr
    SCALE_SIZE_OUTER: gl.constexpr
    SCALE_SIZE_INNER: gl.constexpr
    MXFP_BLOCK_SIZE: gl.constexpr
    SWIGLU_ALPHA: gl.constexpr
    SWIGLU_LIMIT: gl.constexpr
    REDUCTION_N: gl.constexpr
    FLEXPOINT_SATURATE_INF: gl.constexpr
    SWIGLU_SUBTILE_FACTOR: gl.constexpr
    BAND_N: gl.constexpr
    X_GATHER_MULTICAST: gl.constexpr
    W_SCALE_MULTICAST: gl.constexpr
    FORCE_EPILOGUE_WARPS_N1: gl.constexpr

    @gluon.jit
    def apply_block_schedule(self, block_id: gl.tensor) -> tuple[gl.tensor, gl.tensor, gl.tensor, gl.tensor]:
        return apply_block_schedule(
            block_id=block_id,
            grid_m=self.grid_m,
            GRID_N=self.GRID_N,
            slice_offsets=self.x_slice_offs,
            block_schedule=self.x_block_schedule,
            BAND_N=self.BAND_N,
        )
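
# Dataflow between the warp-specialized partitions defined below (wired
# together by the gl.warp_specialize call in ws_matmul_kernel):
#
#   load_activations --(x_bufs ring)---\
#                                       +--> mma_partition --(acc_bufs ring)--> epilogue_partition
#   load_weights -----(w_bufs ring)----/
#
# Each arrow is a ring of staging buffers guarded by empty/ready mbarrier
# pairs: producers wait on "empty" and signal "ready", consumers do the
# reverse.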

@gluon.jit
def load_activations(p: PartitionArgs):
    local_cga_layout: gl.constexpr = ((0, 1), ) if p.USE_2CTA else ()
    offs_layout: gl.constexpr = gl.SliceLayout(
        dim=0,
        parent=gl.BlockedLayout([1, 4], [32, 1], [1, gl.num_warps()], [1, 0], cga_layout=local_cga_layout),
    )
    tile_x_bytes: gl.constexpr = p.x_desc.block_type.nbytes * (p.BLOCK_M_PER_CTA if p.USE_2CTA else p.BLOCK_M)
    idx = 0
    phase = 1
    issued = 0
    for block_id in range(gl.program_id(0), p.num_blocks, p.NUM_SMS):
        pid_m, _, slice_idx, slice_offset = p.apply_block_schedule(block_id)
        off_m = pid_m * p.BLOCK_M
        shape_m = gl.load(p.x_slice_sizes + slice_idx)
        offs_m = off_m + gl.arange(0, p.BLOCK_M, layout=offs_layout)
        mask_m = offs_m < shape_m
        offs_x_m = gl.load(
            p.gather_indx_ptr + slice_offset + offs_m,
            mask=mask_m,
            other=p.x_desc.shape[0],
        )
        for ki in range(p.K_TILES):
            off_k_x = ki * p.BLOCK_K
            empty_bar = p.x_empty_bars.index(idx)
            ready_bar = p.x_ready_bars.index(idx)
            x_buf = p.x_bufs.index(idx)
            mbarrier.wait(empty_bar, phase, pred=issued >= p.x_num_bufs)
            mbarrier.expect(ready_bar, tile_x_bytes)
            tma.async_gather(
                p.x_desc,
                offs_x_m,
                off_k_x,
                ready_bar,
                x_buf,
                multicast=p.USE_2CTA and p.X_GATHER_MULTICAST,
            )
            idx, phase = advance(idx, phase, p.x_num_bufs)
            issued += 1


@gluon.jit
def load_weights(p: PartitionArgs):
    tile_w_bytes: gl.constexpr = p.w_desc.nbytes_per_cta
    tile_scale_bytes: gl.constexpr = p.scale_desc.nbytes_per_cta
    bytes_per_stage: gl.constexpr = tile_w_bytes + tile_scale_bytes
    idx = 0
    phase = 1
    issued = 0
    for block_id in range(gl.program_id(0), p.num_blocks, p.NUM_SMS):
        _, pid_n, slice_idx, _ = p.apply_block_schedule(block_id)
        scale_idx = slice_idx * p.SCALE_FLAT_N + pid_n * p.SCALE_BLOCK_N_DIV
        for ki in range(p.K_TILES):
            off_k_scale = ki * p.BLOCK_K // (p.MXFP_BLOCK_SIZE * p.SCALE_SIZE_INNER)
            w_empty_bar = p.w_empty_bars.index(idx)
            w_ready_bar = p.w_ready_bars.index(idx)
            w_buf = p.w_bufs.index(idx)
            scale_buf = p.w_scale_bufs.index(idx)
            mbarrier.wait(w_empty_bar, phase, pred=issued >= p.w_num_bufs)
            mbarrier.expect(w_ready_bar, bytes_per_stage)
            tma.async_copy_global_to_shared(p.w_desc, [slice_idx, ki, pid_n, 0, 0], w_ready_bar, w_buf)
            tma.async_copy_global_to_shared(
                p.scale_desc,
                [0, scale_idx, off_k_scale, 0, 0],
                w_ready_bar,
                scale_buf,
                multicast=p.USE_2CTA and p.W_SCALE_MULTICAST,
            )
            idx, phase = advance(idx, phase, p.w_num_bufs)
            issued += 1


@gluon.jit
def mma_partition(p: PartitionArgs):
    x_idx = 0
    x_phase = 0
    w_idx = 0
    w_phase = 0
    mma_idx = 0
    mma_phase = 1
    for block_id in range(gl.program_id(0), p.num_blocks, p.NUM_SMS):
        acc_empty_bar = p.acc_empty_bars.index(mma_idx)
        acc_ready_bar = p.acc_ready_bars.index(mma_idx)
        acc_buf = p.acc_bufs.index(mma_idx)
        mbarrier.wait(acc_empty_bar, mma_phase)
        use_acc = False
        for _ in range(p.K_TILES):
            w_ready_bar = p.w_ready_bars.index(w_idx)
            w_empty_bar = p.w_empty_bars.index(w_idx)
            w_buf = p.w_bufs.index(w_idx)
            scale_buf = p.w_scale_bufs.index(w_idx)
            mbarrier.wait(w_ready_bar, w_phase)
            blackwell.tcgen05_copy(
                unswizzle_mx_scale(scale_buf, p.SCALE_SIZE_OUTER, p.SCALE_SIZE_INNER, p.MXFP_BLOCK_SIZE),
                p.w_scale_tmem,
            )
            x_ready_bar = p.x_ready_bars.index(x_idx)
            x_empty_bar = p.x_empty_bars.index(x_idx)
            x_buf = p.x_bufs.index(x_idx)
            mbarrier.wait(x_ready_bar, x_phase)
            blackwell.tcgen05_mma_scaled(
                w_buf.reshape((p.BLOCK_N, p.BLOCK_K // 2)),
                x_buf.permute((1, 0)),
                acc_buf,
                p.w_scale_tmem,
                p.x_scale_tmem,
                a_type="e2m1",
                b_type="e4m3",
                use_acc=use_acc,
            )
            blackwell.tcgen05_commit(x_empty_bar)
            blackwell.tcgen05_commit(w_empty_bar)
            x_idx, x_phase = advance(x_idx, x_phase, p.x_num_bufs)
            w_idx, w_phase = advance(w_idx, w_phase, p.w_num_bufs)
            use_acc = True
        blackwell.tcgen05_commit(acc_ready_bar)
        mma_idx, mma_phase = advance(mma_idx, mma_phase, p.acc_num_bufs)

@gluon.jit
def store_packed_out(
    p: PartitionArgs,
    packed_out,
    off_m,
    out_off_n_packed,
    shape_m,
    slice_offset,
):
    values = pack_fp8x4(packed_out)
    layout: gl.constexpr = values.type.layout
    offs_m = off_m + gl.arange(0, values.shape[0], layout=gl.SliceLayout(1, layout))
    offs_n = out_off_n_packed + gl.arange(0, values.shape[1], layout=gl.SliceLayout(0, layout))
    mask_m = gl.expand_dims(offs_m < shape_m, 1)
    mask_n = gl.expand_dims(offs_n < (p.out_desc.shape[1] + 3) // 4, 0)
    mask = mask_m & mask_n
    ptrs = p.out_ptr.cast(gl.pointer_type(gl.int32), bitcast=True)
    ptrs = ptrs + gl.expand_dims(slice_offset + offs_m, 1) * (p.out_desc.strides[0] // 4)
    ptrs = ptrs + gl.expand_dims(offs_n, 0) * p.out_desc.strides[1]
    gl.store(ptrs, values, mask=mask)


@gluon.jit
def _swiglu_step1(acc_packed, limit):
    gelu, linear = float2.unpack2(acc_packed)
    gelu = gl.minimum(gelu.to(gl.float32), limit)
    linear = gl.minimum(gl.maximum(linear.to(gl.float32), -limit), limit)
    return gelu, linear


@gluon.jit
def _swiglu_step2(gelu, linear, alpha):
    # gelu / (1 + exp(-alpha * gelu)) == gelu * sigmoid(alpha * gelu); the fma
    # below then computes activated * (linear + 1).
    den = 1.0 + libdevice.exp(-alpha * gelu)
    activated = gelu / den
    activated_packed = float2.pack(activated, axis=1)
    linear_packed = float2.pack(linear, axis=1)
    return float2.fma(activated_packed, linear_packed, activated_packed)


@gluon.jit
def pack_fp8_out_fragment(out_packed, out_recip):
    scaled_out_packed = out_packed * float2.full_like(out_packed, out_recip)
    return pack_e4m3x2(scaled_out_packed)


@gluon.jit
def get_store_layout(p: PartitionArgs):
    frag_rows: gl.constexpr = p.BLOCK_M // p.SWIGLU_SUBTILE_FACTOR
    local_cga_layout: gl.constexpr = ((0, 1), ) if p.USE_2CTA else ()
    return gl.BlockedLayout(
        [frag_rows // gl.num_warps(), 2],
        [1, 32],
        [gl.num_warps(), 1],
        [1, 0],
        cga_layout=local_cga_layout,
    )


@gluon.jit
def epilogue_direct_store(
    p: PartitionArgs,
    acc_packed,
    out_recip,
    off_m,
    out_off_n_packed,
    shape_m,
    slice_offset,
    store_layout: gl.constexpr,
):
    frag_rows: gl.constexpr = p.BLOCK_M // p.SWIGLU_SUBTILE_FACTOR
    acc_packed_subtiles = split_m_subtiles(acc_packed, p.SWIGLU_SUBTILE_FACTOR)
    for frag_idx in gl.static_range(p.SWIGLU_SUBTILE_FACTOR):
        gelu, linear = _swiglu_step1(acc_packed_subtiles[frag_idx], p.SWIGLU_LIMIT)
        out_packed = _swiglu_step2(gelu, linear, p.SWIGLU_ALPHA)
        packed_fp8 = gl.convert_layout(pack_fp8_out_fragment(out_packed, out_recip), store_layout)
        store_packed_out(
            p,
            packed_fp8,
            off_m + frag_idx * frag_rows,
            out_off_n_packed,
            shape_m,
            slice_offset,
        )


@gluon.jit
def apply_bias_and_scale(
    p: PartitionArgs,
    idx,
    phase,
    pid_n,
    slice_idx,
    split_layout: gl.constexpr,
    bias_layout: gl.constexpr,
    acc_scale,
):
    off_n = pid_n * p.BLOCK_N
    acc_empty_bar = p.acc_empty_bars.index(idx)
    acc_ready_bar = p.acc_ready_bars.index(idx)
    acc_buf = p.acc_bufs.index(idx)
    offs_bias_n = off_n + gl.arange(0, p.BLOCK_N, layout=bias_layout)
    bias = gl.convert_layout(
        gl.expand_dims(gl.load(p.bias_ptr + slice_idx * p.bias_stride + offs_bias_n), axis=0),
        split_layout,
    )
    mbarrier.wait(acc_ready_bar, phase)
    acc_regs = acc_buf.load().permute((1, 0))
    mbarrier.arrive(acc_empty_bar)
    idx, phase = advance(idx, phase, p.acc_num_bufs)
    acc = gl.convert_layout(acc_regs, split_layout)
    acc_packed = float2.pack(acc, axis=1)
    bias_packed = float2.pack(bias, axis=1)
    bias_packed = float2.Float2Tensor(gl.convert_layout(bias_packed.value, acc_packed.value.type.layout))
    acc_packed = float2.fma(acc_packed, float2.full_like(acc_packed, acc_scale), bias_packed)
    return idx, phase, acc_packed
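
# Reference for the fused SwiGLU above (illustrative only, not called by the
# kernel): _swiglu_step1/_swiglu_step2 together compute
#   s * (clamp(x_lin, -limit, limit) + 1)  with  s = g * sigmoid(alpha * g)
# and g = min(x_glu, limit), since fma(act, lin, act) == act * (lin + 1).
def _swiglu_reference(x_glu: torch.Tensor, x_lin: torch.Tensor, alpha: float, limit: float) -> torch.Tensor:
    g = x_glu.float().clamp(max=limit)
    lin = x_lin.float().clamp(-limit, limit)
    return g * torch.sigmoid(alpha * g) * (lin + 1)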

@gluon.jit
def epilogue_partition(p: PartitionArgs):
    idx = 0
    phase = 0
    x_scale = 1.0 if p.x_scale_ptr is None else gl.load(p.x_scale_ptr)
    w_scale = 1.0 if p.w_scale_ptr is None else gl.load(p.w_scale_ptr)
    acc_scale = x_scale * w_scale
    out_recip = 1.0 / gl.load(p.out_scale_ptr)
    num_warps: gl.constexpr = gl.num_warps()
    warps_n: gl.constexpr = 1 if p.FORCE_EPILOGUE_WARPS_N1 else (2 if num_warps >= 8 and p.BLOCK_N >= 256 else 1)
    split_cga_layout: gl.constexpr = ((0, 1), ) if p.USE_2CTA else ()
    split_layout: gl.constexpr = gl.BlockedLayout(
        [1, 4],
        [1, 32],
        [num_warps // warps_n, warps_n],
        [1, 0],
        cga_layout=split_cga_layout,
    )
    bias_layout: gl.constexpr = gl.SliceLayout(0, split_layout)
    store_layout: gl.constexpr = get_store_layout(p)
    for block_id in range(gl.program_id(0), p.num_blocks, p.NUM_SMS):
        pid_m, pid_n, slice_idx, slice_offset = p.apply_block_schedule(block_id)
        off_m = pid_m * p.BLOCK_M
        shape_m = gl.load(p.x_slice_sizes + slice_idx)
        out_off_n_packed = pid_n * (p.BLOCK_N // p.REDUCTION_N // 4)
        idx, phase, acc_packed = apply_bias_and_scale(
            p,
            idx,
            phase,
            pid_n,
            slice_idx,
            split_layout,
            bias_layout,
            acc_scale,
        )
        epilogue_direct_store(
            p,
            acc_packed,
            out_recip,
            off_m,
            out_off_n_packed,
            shape_m,
            slice_offset,
            store_layout,
        )

@gluon.jit
def ws_matmul_kernel(
    x_desc: tma.tensor_descriptor,
    w_desc: tma.tensor_descriptor,
    scale_desc: tma.tensor_descriptor,
    out_desc: tma.tensor_descriptor,
    out_ptr: gl.tensor,
    #
    bias_ptr: gl.tensor,
    bias_stride: gl.tensor,
    #
    gather_indx_ptr: gl.tensor,
    #
    x_slice_sizes: gl.tensor,
    x_slice_offs: gl.tensor,
    x_block_offs: gl.tensor,
    x_block_schedule: gl.tensor,
    #
    x_scale_ptr: gl.tensor,
    w_scale_ptr: gl.tensor,
    out_scale_ptr: gl.tensor,
    #
    M: gl.constexpr,
    N: gl.constexpr,
    K: gl.constexpr,
    NUM_SLICES: gl.constexpr,
    #
    SWIGLU_ALPHA: gl.constexpr,
    SWIGLU_LIMIT: gl.constexpr,
    REDUCTION_N: gl.constexpr,
    #
    FLEXPOINT_SATURATE_INF: gl.constexpr,
    #
    BLOCK_M: gl.constexpr,
    BLOCK_N: gl.constexpr,
    BLOCK_K: gl.constexpr,
    NUM_SMS: gl.constexpr,
    X_NUM_BUFS: gl.constexpr,
    W_NUM_BUFS: gl.constexpr,
    ACC_NUM_BUFS: gl.constexpr,
    LOAD_ACTIVATION_WARPS: gl.constexpr,
    LOAD_WEIGHT_WARPS: gl.constexpr,
    MMA_WARPS: gl.constexpr,
    LOAD_ACTIVATION_REGS: gl.constexpr,
    LOAD_WEIGHT_REGS: gl.constexpr,
    MMA_REGS: gl.constexpr,
    SWIGLU_SUBTILE_FACTOR: gl.constexpr,
    BAND_N: gl.constexpr,
    X_GATHER_MULTICAST: gl.constexpr,
    W_SCALE_MULTICAST: gl.constexpr,
    FORCE_EPILOGUE_WARPS_N1: gl.constexpr,
    SCALE_SIZE_OUTER: gl.constexpr,
    SCALE_SIZE_INNER: gl.constexpr,
    MXFP_BLOCK_SIZE: gl.constexpr,
):
    use_2cta: gl.constexpr = gl.num_ctas() > 1
    gl.static_assert(gl.num_ctas() == 1 or gl.num_ctas() == 2, "kernel supports at most 2 CTAs")
    gl.static_assert(not use_2cta or BLOCK_N >= 256, "2CTA path requires at least 128 columns per CTA")
    grid_m = gl.load(x_block_offs + NUM_SLICES)
    grid_n: gl.constexpr = triton.cdiv(N, BLOCK_N)
    k_tiles: gl.constexpr = triton.cdiv(K, BLOCK_K)
    scale_flat_n: gl.constexpr = N // SCALE_SIZE_OUTER
    scale_block_n_div: gl.constexpr = BLOCK_N // SCALE_SIZE_OUTER
    num_blocks = grid_m * grid_n
    scale_k: gl.constexpr = BLOCK_K // MXFP_BLOCK_SIZE
    block_m_per_cta: gl.constexpr = BLOCK_M // gl.num_ctas()

    x_scale_layout: gl.constexpr = blackwell.TensorMemoryScalesLayout(cga_layout=((0, 0), ) if use_2cta else (), )
    w_scale_layout: gl.constexpr = blackwell.TensorMemoryScalesLayout(cga_layout=((1, 0), ) if use_2cta else (), )
    mma_block_col: gl.constexpr = min(128, BLOCK_N // gl.num_ctas())
    acc_layout: gl.constexpr = blackwell.TensorMemoryLayout(
        [mma_block_col, BLOCK_M],
        col_stride=1,
        cga_layout=((1, 0), ) if use_2cta else (),
        two_ctas=use_2cta,
    )

    x_num_bufs: gl.constexpr = X_NUM_BUFS
    x_bufs = gl.allocate_shared_memory(
        x_desc.dtype,
        [x_num_bufs, BLOCK_M, x_desc.block_type.shape[1]],
        x_desc.layout,
    )
    x_empty_bars, x_ready_bars = alloc_ring_barriers(x_num_bufs, consumer_two_ctas=use_2cta)

    w_num_bufs: gl.constexpr = W_NUM_BUFS
    w_bufs = gl.allocate_shared_memory(
        w_desc.dtype,
        [w_num_bufs] + w_desc.block_type.shape,
        w_desc.layout,
    )
    w_scale_bufs = gl.allocate_shared_memory(
        scale_desc.dtype,
        [w_num_bufs] + scale_desc.block_type.shape,
        scale_desc.layout,
    )
    w_empty_bars, w_ready_bars = alloc_ring_barriers(w_num_bufs, consumer_two_ctas=use_2cta)

    x_scale_tmem = blackwell.allocate_tensor_memory(gl.uint8, [BLOCK_M, scale_k], x_scale_layout)
    w_scale_tmem = blackwell.allocate_tensor_memory(gl.uint8, [BLOCK_N, scale_k], w_scale_layout)

    acc_num_bufs: gl.constexpr = ACC_NUM_BUFS
    acc_tmem = blackwell.allocate_tensor_memory(
        gl.float32,
        [acc_num_bufs, BLOCK_N, BLOCK_M],
        acc_layout,
    )
    acc_empty_bars, acc_ready_bars = alloc_ring_barriers(acc_num_bufs, producer_two_ctas=use_2cta)

    # 127 is the E8M0 encoding of 1.0 (biased exponent 127 -> 2^0), so the
    # activation operand is effectively unscaled.
    x_scale_tmem.store(gl.full((BLOCK_M, scale_k), 127, dtype=gl.uint8, layout=x_scale_tmem.get_reg_layout()))

    p = PartitionArgs(
        x_desc=x_desc,
        w_desc=w_desc,
        scale_desc=scale_desc,
        out_desc=out_desc,
        x_scale_ptr=x_scale_ptr,
        w_scale_ptr=w_scale_ptr,
        out_scale_ptr=out_scale_ptr,
        #
        out_ptr=out_ptr,
        bias_ptr=bias_ptr,
        bias_stride=bias_stride,
        gather_indx_ptr=gather_indx_ptr,
        x_slice_sizes=x_slice_sizes,
        x_slice_offs=x_slice_offs,
        x_block_schedule=x_block_schedule,
        #
        x_bufs=x_bufs,
        x_empty_bars=x_empty_bars,
        x_ready_bars=x_ready_bars,
        x_num_bufs=x_num_bufs,
        #
        w_bufs=w_bufs,
        w_scale_bufs=w_scale_bufs,
        w_empty_bars=w_empty_bars,
        w_ready_bars=w_ready_bars,
        w_num_bufs=w_num_bufs,
        #
        x_scale_tmem=x_scale_tmem,
        w_scale_tmem=w_scale_tmem,
        acc_bufs=acc_tmem,
        acc_empty_bars=acc_empty_bars,
        acc_ready_bars=acc_ready_bars,
        acc_num_bufs=acc_num_bufs,
        #
        grid_m=grid_m,
        GRID_N=grid_n,
        K_TILES=k_tiles,
        SCALE_FLAT_N=scale_flat_n,
        SCALE_BLOCK_N_DIV=scale_block_n_div,
        num_blocks=num_blocks,
        #
        NUM_SMS=NUM_SMS,
        USE_2CTA=use_2cta,
        BLOCK_M_PER_CTA=block_m_per_cta,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        SCALE_SIZE_OUTER=SCALE_SIZE_OUTER,
        SCALE_SIZE_INNER=SCALE_SIZE_INNER,
        MXFP_BLOCK_SIZE=MXFP_BLOCK_SIZE,
        #
        SWIGLU_ALPHA=SWIGLU_ALPHA,
        SWIGLU_LIMIT=SWIGLU_LIMIT,
        REDUCTION_N=REDUCTION_N,
        FLEXPOINT_SATURATE_INF=FLEXPOINT_SATURATE_INF,
        #
        SWIGLU_SUBTILE_FACTOR=SWIGLU_SUBTILE_FACTOR,
        BAND_N=BAND_N,
        X_GATHER_MULTICAST=X_GATHER_MULTICAST,
        W_SCALE_MULTICAST=W_SCALE_MULTICAST,
        FORCE_EPILOGUE_WARPS_N1=FORCE_EPILOGUE_WARPS_N1,
    )
    gl.warp_specialize(
        [
            (epilogue_partition, (p, )),
            (load_activations, (p, )),
            (load_weights, (p, )),
            (mma_partition, (p, )),
        ],
        [LOAD_ACTIVATION_WARPS, LOAD_WEIGHT_WARPS, MMA_WARPS],
        [LOAD_ACTIVATION_REGS, LOAD_WEIGHT_REGS, MMA_REGS],
    )


# ===-----------------------------------------------------------------------===#
# Host Code
# ===-----------------------------------------------------------------------===#

def make_tensor_descriptor(
    t: torch.Tensor | Tensor,
    block_shape: tuple[int, ...],
    *,
    layout_block_shape: tuple[int, ...] | None = None,
    cga_layout: tuple[tuple[int, ...], ...] = (),
):
    from triton.experimental.gluon.nvidia.hopper import TensorDescriptor

    ptr = t if isinstance(t, torch.Tensor) else t.storage.data
    shape = list(ptr.shape)
    strides = list(ptr.stride())
    desc_block_shape = list(block_shape)
    layout_shape = list(layout_block_shape or block_shape)
    if isinstance(t, Tensor) and t.dtype == FP4:
        assert isinstance(t.storage.layout, BlackwellMX4ValueShuffledLayout)
        assert layout_block_shape is None
        desc_block_shape = t.storage.layout.swizzle_block_shape(desc_block_shape)
        desc_block_shape[strides.index(1)] //= 2
        layout_shape = desc_block_shape
    rank = len(layout_shape)
    if t.dtype == FP4:
        assert rank == 5
        layout = gl.NVMMASharedLayout(
            swizzle_byte_width=128,
            element_bitwidth=8,
            rank=rank,
            fp4_padded=True,
            cga_layout=cga_layout,
        )
    elif t.dtype == UINT8:
        assert rank == 5
        layout = gl.NVMMASharedLayout(
            swizzle_byte_width=0,
            element_bitwidth=8,
            rank=rank,
            cga_layout=cga_layout,
        )
    elif t.dtype == torch.float32:
        assert rank == 2
        layout = gl.NVMMASharedLayout.get_default_for(
            layout_shape,
            torch.float32,
            cga_layout=cga_layout,
        )
    else:
        assert t.dtype == torch.float8_e4m3fn
        layout = gl.NVMMASharedLayout(
            swizzle_byte_width=layout_shape[-1],
            element_bitwidth=8,
            rank=rank,
            cga_layout=cga_layout,
        )
    return TensorDescriptor(ptr, shape, strides, desc_block_shape, layout)


@dataclass(frozen=True, slots=True)
class KernelConfig:
    BLOCK_M: int = 128
    BLOCK_N: int = 256
    BLOCK_K: int = 128
    NUM_CTAS: int = 1
    X_NUM_BUFS: int = 5
    W_NUM_BUFS: int = 4
    ACC_NUM_BUFS: int = 1
    NUM_WARPS: int = 8
    LOAD_ACTIVATION_WARPS: int = 4
    LOAD_WEIGHT_WARPS: int = 1
    MMA_WARPS: int = 1
    SWIGLU_SUBTILE_FACTOR: int = 8
    BAND_N: int = 20
    X_GATHER_MULTICAST: bool = True
    W_SCALE_MULTICAST: bool = True
    FORCE_EPILOGUE_WARPS_N1: bool = False
    LOAD_ACTIVATION_REGS: int = 112
    LOAD_WEIGHT_REGS: int = 48
    MMA_REGS: int = 48
    MAXNREG: int | None = None
    OCCUPANCY: int = 1
    MXFP_BLOCK_SIZE: int = 32
    SCALE_SIZE_OUTER: int = 128
    SCALE_SIZE_INNER: int = 4

    def get_x_tile_smem(self) -> int:
        return self.BLOCK_M * self.BLOCK_K

    def get_w_tile_smem(self) -> int:
        return self.BLOCK_N * self.BLOCK_K

    def get_w_mx_tile_smem(self) -> int:
        return self.get_w_tile_smem() // self.MXFP_BLOCK_SIZE

    def get_c_tile_smem(self, reduction_n: int) -> int:
        return (self.BLOCK_M // self.SWIGLU_SUBTILE_FACTOR) * (self.BLOCK_N // reduction_n)


def _select_base_config(slice_size: int) -> KernelConfig:
    if slice_size <= 14:
        block_m = 16
    elif slice_size <= 32:
        block_m = 32
    elif slice_size <= 64:
        block_m = 64
    else:
        block_m = 128
    if slice_size <= 64:
        x_num_bufs, w_num_bufs = {
            16: (10, 5),
            32: (5, 5),
            64: (4, 4),
        }[block_m]
        return KernelConfig(
            BLOCK_M=block_m,
            BLOCK_N=128,
            X_NUM_BUFS=x_num_bufs,
            W_NUM_BUFS=w_num_bufs,
            SWIGLU_SUBTILE_FACTOR=min(8, block_m // 8),
            OCCUPANCY=2,
            MAXNREG=64,
            LOAD_ACTIVATION_REGS=48,
            LOAD_WEIGHT_REGS=32,
            MMA_REGS=32,
        )
    return KernelConfig(
        BLOCK_M=block_m,
        SWIGLU_SUBTILE_FACTOR=min(8, block_m // 8),
    )

def select_kernel_config(slice_size: int) -> KernelConfig:
    p = _select_base_config(slice_size)
    if p.BLOCK_M == 32 and p.BLOCK_N == 128 and slice_size in (16, 20, 24, 32):
        p = replace(
            p,
            BLOCK_N=256,
            NUM_CTAS=2,
            NUM_WARPS=4,
            W_NUM_BUFS=5,
            SWIGLU_SUBTILE_FACTOR=1,
            LOAD_ACTIVATION_WARPS=2,
            LOAD_WEIGHT_WARPS=1,
            MMA_WARPS=1,
            LOAD_ACTIVATION_REGS=40,
            LOAD_WEIGHT_REGS=32,
            MMA_REGS=32,
            MAXNREG=52,
            BAND_N=32,
            FORCE_EPILOGUE_WARPS_N1=True,
        )
        if slice_size == 16:
            p = replace(p, X_NUM_BUFS=5)
        elif slice_size in (20, 24):
            p = replace(p, X_NUM_BUFS=6)
        else:
            p = replace(p, X_NUM_BUFS=6, X_GATHER_MULTICAST=False, W_SCALE_MULTICAST=False)
    elif 36 <= slice_size <= 72:
        p = replace(
            p,
            BLOCK_M=64,
            BLOCK_N=256,
            NUM_CTAS=2,
            OCCUPANCY=2,
            X_NUM_BUFS=6,
            W_NUM_BUFS=5,
            SWIGLU_SUBTILE_FACTOR=4,
            LOAD_ACTIVATION_REGS=64,
            LOAD_WEIGHT_REGS=48,
            MMA_REGS=48,
            MAXNREG=64,
        )
    elif p.BLOCK_M == 128 and p.BLOCK_N == 256 and slice_size >= 80:
        p = replace(p, BLOCK_N=512, NUM_CTAS=2, W_NUM_BUFS=5)
    if p.BLOCK_M == 32 and p.BLOCK_N == 256 and p.NUM_CTAS == 2 and slice_size <= 32:
        return replace(p, BAND_N=32)
    if p.BLOCK_M == 64 and p.BLOCK_N == 256 and p.NUM_CTAS == 2 and 36 <= slice_size <= 72:
        return replace(p, BAND_N=26)
    if slice_size < 32:
        return replace(p, BAND_N=22)
    if slice_size < 416:
        return replace(p, BAND_N=18)
    return replace(p, BAND_N=26)
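
# Worked example of the heuristic above (values read off the code, assuming an
# expected slice size of 64 rows per expert): _select_base_config picks
# BLOCK_M=64/BLOCK_N=128 with (X_NUM_BUFS, W_NUM_BUFS) = (4, 4); the
# 36 <= slice_size <= 72 branch then widens it to a 2-CTA 64x256 tile with
# X_NUM_BUFS=6 and W_NUM_BUFS=5, and the final BAND_N pass returns BAND_N=26.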

def matmul(
    a: torch.Tensor,
    b: torch.Tensor | Tensor,
    bias: torch.Tensor,
    a_ragged_metadata: RaggedTensorMetadata,
    gather_indx: torch.Tensor,
    precision_config: PrecisionConfig,
    c: torch.Tensor,
    fused_activation: FusedActivation,
    p: KernelConfig | None = None,
):
    specs = fused_activation.specs
    assert specs.name == "swiglu"
    reduction_n = specs.reduction_n
    swiglu_alpha, swiglu_limit = fused_activation.fn_args
    b_mx_scales = precision_config.b_mx_scale
    out_dtype = precision_config.out_dtype
    assert out_dtype is not None
    assert c.ndim == 2
    flex_ctx = precision_config.flex_ctx
    assert a.ndim == 2
    _, k = a.shape
    _, _, n = b.shape
    m = gather_indx.shape[0]
    p = p or select_kernel_config(a_ragged_metadata.expected_slice_size)
    assert isinstance(b, Tensor)
    assert isinstance(b.storage.layout, BlackwellMX4ValueShuffledLayout)
    assert b.storage.layout.block_k == p.BLOCK_K
    assert b.storage.layout.block_n == p.BLOCK_N
    x_block_offs = a_ragged_metadata.block_offs(p.BLOCK_M)
    x_block_schedule = a_ragged_metadata.block_schedule(p.BLOCK_M)
    expected_grid_m = a_ragged_metadata.n_blocks(a_ragged_metadata.n_slices, m, p.BLOCK_M)
    grid_n = triton.cdiv(n, p.BLOCK_N)
    sms = torch.cuda.get_device_properties(bias.device).multi_processor_count
    sms *= p.OCCUPANCY
    launch_grid = max(1, min(max(1, sms // p.NUM_CTAS), expected_grid_m * grid_n))
    grid = (launch_grid, )

    if p.NUM_CTAS == 1:
        acc_cga_layout = ()
    elif p.NUM_CTAS == 2:
        acc_cga_layout = ((1, 0), )
    else:
        raise ValueError(f"unsupported CTA count: {p.NUM_CTAS}")
    x_desc = make_tensor_descriptor(
        a,
        (1, p.BLOCK_K),
        layout_block_shape=(p.BLOCK_M, p.BLOCK_K),
        cga_layout=tuple((basis[0], 0) for basis in acc_cga_layout),
    )
    w_desc = make_tensor_descriptor(
        b,
        (1, p.BLOCK_K, p.BLOCK_N),
        # Sharded weight tiles use the physical [1, 1, 1, N, K/2] MX4 shuffled block layout.
        cga_layout=tuple((0, 0, 0, basis[0], 0) for basis in acc_cga_layout),
    )
    scale_desc = make_tensor_descriptor(
        b_mx_scales,
        (
            1,
            p.BLOCK_N // p.SCALE_SIZE_OUTER,
            p.BLOCK_K // p.MXFP_BLOCK_SIZE // p.SCALE_SIZE_INNER,
            2,
            256,
        ),
        # Weight scale tiles use the physical [1, N//128, K//(32*4), 2, 256] layout.
        cga_layout=tuple((0, basis[0], 0, 0, 0) for basis in acc_cga_layout),
    )
    # The output descriptor is only used for shape/stride metadata during
    # direct stores, so cap the layout width to a legal FP8 swizzle size.
    out_desc = make_tensor_descriptor(c, (p.BLOCK_M, min(p.BLOCK_N // reduction_n, 128)))

    ws_matmul_kernel[grid](
        x_desc=x_desc,
        w_desc=w_desc,
        scale_desc=scale_desc,
        out_desc=out_desc,
        out_ptr=c,
        #
        bias_ptr=bias,
        bias_stride=bias.stride(0),
        #
        gather_indx_ptr=gather_indx,
        #
        x_slice_sizes=a_ragged_metadata.slice_sizes,
        x_slice_offs=a_ragged_metadata.slice_offs,
        x_block_offs=x_block_offs,
        x_block_schedule=x_block_schedule,
        #
        x_scale_ptr=flex_ctx.lhs_data.scale,
        w_scale_ptr=flex_ctx.rhs_data.scale,
        out_scale_ptr=flex_ctx.out_data.expected_scale,
        #
        M=m,
        N=n,
        K=k,
        NUM_SLICES=a_ragged_metadata.n_slices,
        #
        SWIGLU_ALPHA=swiglu_alpha,
        SWIGLU_LIMIT=swiglu_limit,
        REDUCTION_N=reduction_n,
        #
        FLEXPOINT_SATURATE_INF=precision_config.flexpoint_saturate_inf,
        #
        BLOCK_M=p.BLOCK_M,
        BLOCK_N=p.BLOCK_N,
        BLOCK_K=p.BLOCK_K,
        NUM_SMS=launch_grid,
        X_NUM_BUFS=p.X_NUM_BUFS,
        W_NUM_BUFS=p.W_NUM_BUFS,
        ACC_NUM_BUFS=p.ACC_NUM_BUFS,
        LOAD_ACTIVATION_WARPS=p.LOAD_ACTIVATION_WARPS,
        LOAD_WEIGHT_WARPS=p.LOAD_WEIGHT_WARPS,
        MMA_WARPS=p.MMA_WARPS,
        LOAD_ACTIVATION_REGS=p.LOAD_ACTIVATION_REGS,
        LOAD_WEIGHT_REGS=p.LOAD_WEIGHT_REGS,
        MMA_REGS=p.MMA_REGS,
        SWIGLU_SUBTILE_FACTOR=p.SWIGLU_SUBTILE_FACTOR,
        BAND_N=p.BAND_N,
        X_GATHER_MULTICAST=p.X_GATHER_MULTICAST,
        W_SCALE_MULTICAST=p.W_SCALE_MULTICAST,
        FORCE_EPILOGUE_WARPS_N1=p.FORCE_EPILOGUE_WARPS_N1,
        #
        SCALE_SIZE_OUTER=p.SCALE_SIZE_OUTER,
        SCALE_SIZE_INNER=p.SCALE_SIZE_INNER,
        MXFP_BLOCK_SIZE=p.MXFP_BLOCK_SIZE,
        #
        num_warps=p.NUM_WARPS,
        num_ctas=p.NUM_CTAS,
        maxnreg=p.MAXNREG,
    )
    return c


# ===-----------------------------------------------------------------------===#
# Benchmark and Testing Helpers
# ===-----------------------------------------------------------------------===#


@dataclass(frozen=True, slots=True)
class MLPConfig:
    name: str
    num_experts: int
    experts_per_token: int
    num_expert_shards: int
    hidden_size: int
    intermediate_size: int


def get_batch_sizes(c: MLPConfig) -> tuple[int, ...]:
    batch_per_expert = tuple(chain.from_iterable(range(2**(2 + k), 2**(3 + k), min(2**k, 32)) for k in range(8)))
    return tuple(batch_per_expert * c.num_experts // c.experts_per_token for batch_per_expert in batch_per_expert)


@dataclass(frozen=True, slots=True)
class PreparedCase:
    batch_size: int
    local_rank: int
    x: torch.Tensor
    w: Tensor
    w_scale: Tensor
    bias: torch.Tensor
    ragged_metadata: RaggedTensorMetadata
    gather_indx: torch.Tensor
    fused_activation: FusedActivation
    x_scale: torch.Tensor
    y_scale: torch.Tensor
    out_shape: tuple[int, int]
    out_dtype: torch.dtype


def alloc_randn(shape: tuple[int, ...], dtype: torch.dtype, device: str) -> torch.Tensor:
    if dtype.itemsize == 1:
        return alloc_rand(shape, device=device, dtype=dtype)
    return torch.randn(shape, device=device, dtype=dtype)


def alloc_randn_fp4(shape: tuple[int, ...], device: str, p: KernelConfig | None) -> tuple[Tensor, Tensor]:
    if p is not None:
        block_k, block_n, num_warps = p.BLOCK_K, p.BLOCK_N, p.NUM_WARPS
    else:
        block_k, block_n, num_warps = 128, 256, 8
    data = alloc_randn(shape, torch.bfloat16, device)
    data, scale = downcast_to_mxfp(data, FP4, axis=1)  # type: ignore[arg-type]
    data_layout = BlackwellMX4ValueShuffledLayout(block_k=block_k, block_n=block_n)
    scale_layout = make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
    data = convert_layout(wrap_torch_tensor(data, dtype=FP4), data_layout)
    scale = convert_layout(wrap_torch_tensor(scale), scale_layout)
    return data, scale

def make_prod_like_logits(
    batch_size: int,
    num_experts: int,
    experts_per_token: int,
    device: str,
    dtype: torch.dtype = torch.float16,
    *,
    zipf_alpha: float = 1.10,
    num_clusters: int = 16,
    cluster_boost: float = 1.25,
    gumbel_scale: float = 0.75,
    batch_hot_experts: int = 4,
    batch_hot_boost: float = 0.6,
) -> torch.Tensor:
    # Stable expert popularity: a few hot experts, long tail.
    ranks = torch.arange(1, num_experts + 1, device=device, dtype=torch.float32)
    ranked_probs = ranks.pow(-zipf_alpha)
    ranked_probs /= ranked_probs.sum()
    # Randomize which expert ids are hot so shard/id layout is not special.
    perm = torch.randperm(num_experts, device=device)
    expert_probs = torch.empty_like(ranked_probs)
    expert_probs[perm] = ranked_probs
    logits = expert_probs.clamp_min(1e-12).log()[None, :].expand(batch_size, -1).clone()
    # Token locality: each token belongs to a synthetic topic/cluster with preferred experts.
    cluster_size = min(num_experts, max(2 * experts_per_token, num_experts // 16))
    cluster_experts = torch.stack(
        [torch.multinomial(expert_probs, cluster_size, replacement=False) for _ in range(num_clusters)])
    token_cluster = torch.randint(num_clusters, (batch_size, ), device=device)
    rows = torch.arange(batch_size, device=device)[:, None]
    logits[rows, cluster_experts[token_cluster]] += cluster_boost
    # Batch burstiness: a few experts are hotter for this batch.
    if batch_hot_experts > 0:
        hot = torch.multinomial(expert_probs, batch_hot_experts, replacement=False)
        logits[:, hot] += batch_hot_boost
    # Gumbel noise makes top-k behave like weighted sampling without replacement.
    noise = -torch.empty_like(logits).exponential_().log()
    logits += gumbel_scale * noise
    return logits.to(dtype)


def init_routing_data(c: MLPConfig, batch_size: int, local_rank: int, device: str,
                      uniform_routing: bool) -> tuple[RaggedTensorMetadata, torch.Tensor]:
    expt_dist = make_expt_dict_uniform(c.num_expert_shards, c.num_experts)
    if uniform_routing:
        logits = torch.randn((batch_size, c.num_experts), dtype=torch.float16, device=device)
    else:
        logits = make_prod_like_logits(batch_size, c.num_experts, c.experts_per_token, device)
    sparse_logits = topk(logits, c.experts_per_token, apply_softmax=True)
    expt_hist = sparse_logits.mask_metadata.col_sum
    local_expts = expt_dist[local_rank]
    local_expts_hist = expt_hist[local_expts]
    ragged_metadata = make_ragged_tensor_metadata(local_expts_hist, batch_size * c.experts_per_token)
    ragged_metadata.expected_slice_size = batch_size * c.experts_per_token // c.num_experts
    combine_indx = sparse_logits.mask_metadata.col_sorted_indx
    gather_indx = torch.div(combine_indx, c.experts_per_token, rounding_mode="trunc")
    return ragged_metadata, gather_indx


def prepare_case(c: MLPConfig, batch_size: int, device: str, seed: int = 0, uniform_routing: bool = False,
                 reference: bool = False) -> PreparedCase:
    torch.manual_seed(seed)
    local_rank = int(torch.randint(0, c.num_expert_shards, size=()).item())
    k, n = c.hidden_size, c.intermediate_size
    n_expts_local = c.num_experts // c.num_expert_shards
    ragged_metadata, gather_indx = init_routing_data(c, batch_size, local_rank, device, uniform_routing)
    p = None if reference else select_kernel_config(ragged_metadata.expected_slice_size)
    x = alloc_randn((batch_size, k), dtype=torch.float8_e4m3fn, device=device)
    w, w_scale = alloc_randn_fp4((n_expts_local, k, n), device=device, p=p)
    bias = alloc_randn((n_expts_local, n), dtype=torch.float32, device=device)
    swiglu_alpha = float(torch.rand((), device=device).item()) / 5 + 1.0
    swiglu_limit = float(torch.rand((), device=device).item()) / 5 + 1.3
    fused_activation = FusedActivation(
        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2),
        (swiglu_alpha, swiglu_limit),
    )
    x_scale = (torch.rand((), device=device) + 0.5).reshape(1)
    y_scale = (torch.rand((), device=device) + 3.5).reshape(1)
    return PreparedCase(
        batch_size=batch_size,
        local_rank=local_rank,
        x=x,
        w=w,
        w_scale=w_scale,
        bias=bias,
        ragged_metadata=ragged_metadata,
        gather_indx=gather_indx,
        fused_activation=fused_activation,
        x_scale=x_scale,
        y_scale=y_scale,
        out_shape=(batch_size * c.experts_per_token, n // fused_activation.specs.reduction_n),
        out_dtype=torch.float8_e4m3fn,
    )

def make_precision_config(prepared: PreparedCase) -> PrecisionConfig:
    return PrecisionConfig(
        flexpoint_saturate_inf=True,
        b_mx_scale=prepared.w_scale,
        b_microblock_size=MXFP_BLOCK_SIZE.value,
        out_dtype=prepared.out_dtype,
        flex_ctx=FlexCtx(
            lhs_data=InFlexData(dtype=prepared.out_dtype, scale=prepared.x_scale),
            rhs_data=InFlexData(),
            out_data=OutFlexData(dtype=prepared.out_dtype, expected_scale=prepared.y_scale),
        ),
    )


def make_output_buffer(prepared: PreparedCase) -> torch.Tensor:
    return torch.zeros(prepared.out_shape, dtype=prepared.out_dtype, device=prepared.x.device)


def run_kernel(prepared: PreparedCase, kernel, precision_config: PrecisionConfig, out: torch.Tensor) -> torch.Tensor:
    return kernel(
        a=prepared.x,
        b=prepared.w,
        bias=prepared.bias,
        a_ragged_metadata=prepared.ragged_metadata,
        gather_indx=prepared.gather_indx,
        precision_config=precision_config,
        c=out,
        fused_activation=prepared.fused_activation,
    )


def run_provider(prepared: PreparedCase, provider: str) -> tuple[torch.Tensor, PrecisionConfig]:
    precision_config = make_precision_config(prepared)
    kernel = matmul if provider == "example" else reference_matmul
    y = run_kernel(prepared, kernel, precision_config, make_output_buffer(prepared))
    return y, precision_config


def _storage_nbytes(x: torch.Tensor | Tensor) -> int:
    if isinstance(x, Tensor):
        data = x.storage.data
        return int(data.numel() * data.element_size())
    return int(x.numel() * x.element_size())


def estimate_benchmark_work(c: MLPConfig, prepared: PreparedCase) -> tuple[int, int]:
    slice_sizes = prepared.ragged_metadata.slice_sizes
    n_tokens = int(slice_sizes.sum().item())
    active_slices = int((slice_sizes > 0).sum().item())
    n_slices = prepared.ragged_metadata.n_slices
    k, n = c.hidden_size, c.intermediate_size
    out_n = n // prepared.fused_activation.specs.reduction_n
    active_slice_bytes = active_slices * sum(
        _storage_nbytes(t) // n_slices for t in (prepared.w, prepared.w_scale, prepared.bias))
    flops = 2 * n_tokens * k * n
    nbytes = (n_tokens * k * prepared.x.element_size() + active_slice_bytes + n_tokens * out_n * torch.empty(
        (), dtype=prepared.out_dtype).element_size())
    return flops, nbytes


def benchmark_kernel(prepared: PreparedCase, kernel, flops: int, nbytes: int) -> tuple[float, float]:
    precision_config = make_precision_config(prepared)
    out = make_output_buffer(prepared)
    ms = do_bench_cudagraph(lambda: run_kernel(prepared, kernel, precision_config, out))
    seconds = ms * 1e-3
    return flops * 1e-12 / seconds, nbytes * 1e-12 / seconds


# ===-----------------------------------------------------------------------===#
# Unit Tests
# ===-----------------------------------------------------------------------===#

GPT_OSS_120B_CONFIG = MLPConfig(
    name="gpt-oss-120b",
    num_experts=128,
    experts_per_token=4,
    num_expert_shards=8,
    hidden_size=2880,
    intermediate_size=2 * 2880,
)


def is_blackwell():
    return (triton.runtime.driver.active.get_current_target().backend == "cuda"
            and torch.cuda.get_device_capability()[0] == 10)

@pytest.mark.parametrize("c", [GPT_OSS_120B_CONFIG])
@pytest.mark.parametrize("batch_size", get_batch_sizes(GPT_OSS_120B_CONFIG))
@pytest.mark.skipif(not is_blackwell(), reason="Gluon MoE BMM1 fused-gather is only supported on Blackwell GPUs")
def test_op(c: MLPConfig, batch_size: int):
    prepared = prepare_case(c, batch_size, device=f"cuda:{torch.cuda.current_device()}")
    ref_y, ref_precision = run_provider(prepared, "reference")
    cand_y, cand_precision = run_provider(prepared, "example")
    description = f"{c.name}-mm1-bs{prepared.batch_size}"
    assert_close(
        ref_y.to(torch.float32),
        cand_y.to(torch.float32),
        maxtol=0.125,
        rmstol=None,
        description=f"{description}:out",
        verbose=False,
    )
    ref_scale = ref_precision.flex_ctx.out_data.actual_scale
    cand_scale = cand_precision.flex_ctx.out_data.actual_scale
    if ref_scale is not None or cand_scale is not None:
        assert ref_scale is not None and cand_scale is not None
        assert_close(
            ref_scale.to(torch.float32),
            cand_scale.to(torch.float32),
            maxtol=1e-10,
            rmstol=1e-10,
            description=f"{description}:out_scale",
            verbose=False,
        )


# ===-----------------------------------------------------------------------===#
# Benchmarking
# ===-----------------------------------------------------------------------===#

BENCH_TITLE = ("GPT-OSS-120B MoE MM1 "
               f"E={GPT_OSS_120B_CONFIG.num_experts} "
               f"EP={GPT_OSS_120B_CONFIG.experts_per_token} "
               f"ES={GPT_OSS_120B_CONFIG.num_expert_shards} "
               f"B={GPT_OSS_120B_CONFIG.hidden_size}x{GPT_OSS_120B_CONFIG.intermediate_size}")

PEAK_TFLOPS = 5_000.0
PEAK_TBPS = 8.0


def _format_perf(result: tuple[float, float]) -> str:
    tflops, tbps = result
    return f"{tflops:8.2f} TFLOPS ({tflops / PEAK_TFLOPS:6.1%}) {tbps:6.2f} TBPS ({tbps / PEAK_TBPS:6.1%})"


def bench(c: MLPConfig = GPT_OSS_120B_CONFIG, uniform_routing: bool = False):
    batch_sizes = get_batch_sizes(c)
    batch_width = max(len("batch_size"), *(len(str(bs)) for bs in batch_sizes))
    perf_width = max(
        len("reference"),
        len(_format_perf((99999.99, 999.99))),
    )
    print(BENCH_TITLE, flush=True)
    print(f"Peak: {PEAK_TFLOPS / 1000:g} PFLOPS, {PEAK_TBPS:g} TBPS", flush=True)
    print(
        f"{'batch_size':>{batch_width}}  {'example':>{perf_width}}  {'reference':>{perf_width}}",
        flush=True,
    )
    print("-" * (batch_width + 2 + perf_width + 2 + perf_width), flush=True)
    device = f"cuda:{torch.cuda.current_device()}"
    for batch_size in batch_sizes:
        print(f"{batch_size:>{batch_width}}  ", end="", flush=True)
        prepared = prepare_case(
            c,
            batch_size,
            device=device,
            uniform_routing=uniform_routing,
        )
        flops, nbytes = estimate_benchmark_work(c, prepared)
        example = benchmark_kernel(prepared, matmul, flops, nbytes)
        print(f"{_format_perf(example):>{perf_width}}  ", end="", flush=True)
        prepared = prepare_case(
            c,
            batch_size,
            device=device,
            uniform_routing=uniform_routing,
            reference=True,
        )
        flops, nbytes = estimate_benchmark_work(c, prepared)
        reference = benchmark_kernel(prepared, reference_matmul, flops, nbytes)
        print(f"{_format_perf(reference):>{perf_width}}", flush=True)


if __name__ == "__main__":
    bench(uniform_routing=False)
```
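
Running the file directly executes `bench(uniform_routing=False)` through the `__main__` guard; running it under pytest executes `test_op` across the generated batch sizes. A minimal programmatic driver for the test sweep, assuming a Blackwell (SM100) GPU and the `triton_kernels` package are available in the environment:

```python
# Illustrative sketch only: invoke pytest on the example file from Python.
# The path is the one given at the top of this page.
import pytest

pytest.main(["python/examples/gluon/05-moe-bmm1-fused-gather.py", "-q"])
```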