MOE Bmm1 Fused Gather

This example can be found at python/examples/gluon/05-moe-bmm1-fused-gather.py.

from dataclasses import dataclass, replace
from itertools import chain

import pytest
import torch
import triton
import triton.experimental.gluon as gluon
import triton.experimental.gluon.language as gl
import triton.experimental.gluon.language.nvidia.blackwell as blackwell
import triton.experimental.gluon.language.nvidia.blackwell.tma as tma
from triton.experimental.gluon.language.nvidia.blackwell import float2
import triton.experimental.gluon.language.nvidia.hopper.mbarrier as mbarrier
import triton.language.extra.libdevice as libdevice
from triton.testing import do_bench_cudagraph

from triton_kernels.distributed import make_expt_dict_uniform
from triton_kernels.matmul import (
    FlexCtx,
    FnSpecs,
    FusedActivation,
    PrecisionConfig,
    matmul as reference_matmul,
)
from triton_kernels.numerics import InFlexData, OutFlexData
from triton_kernels.numerics_details.mxfp import MXFP_BLOCK_SIZE, downcast_to_mxfp
from triton_kernels.swiglu import swiglu_fn
from triton_kernels.tensor import (
    FP4,
    RaggedTensorMetadata,
    Tensor,
    convert_layout,
    make_ragged_tensor_metadata,
    wrap_torch_tensor,
)
from triton_kernels.tensor_details.dtype import UINT8
from triton_kernels.tensor_details.layout import (
    BlackwellMX4ValueShuffledLayout,
    make_default_matmul_mxfp4_w_scale_layout,
)
from triton_kernels.testing import alloc_rand, assert_close
from triton_kernels.topk import topk

# ===-----------------------------------------------------------------------===#
# Device Code
# ===-----------------------------------------------------------------------===#


@gluon.jit
def advance(idx: gl.tensor, phase: gl.tensor, num_bufs: gl.constexpr) -> tuple[gl.tensor, gl.tensor]:
    next_idx = idx + 1
    wrap = next_idx == num_bufs
    return gl.where(wrap, 0, next_idx), gl.where(wrap, phase ^ 1, phase)


@gluon.jit
def unpack_block_schedule(schedule: gl.tensor) -> tuple[gl.tensor, gl.tensor]:
    return schedule & 0xFFFF, schedule >> 16


@gluon.jit
def banded_row_major(block_id, grid_m, GRID_N: gl.constexpr, BAND_N: gl.constexpr):
    if BAND_N >= GRID_N:
        return block_id // GRID_N, block_id % GRID_N

    full_band_tiles = grid_m * BAND_N
    n_full_bands = GRID_N // BAND_N
    full_band_work = n_full_bands * full_band_tiles

    if block_id < full_band_work:
        band_id = block_id // full_band_tiles
        within_band = block_id % full_band_tiles
        return within_band // BAND_N, band_id * BAND_N + (within_band % BAND_N)

    tail_n = GRID_N - n_full_bands * BAND_N
    tail_idx = block_id - full_band_work
    return tail_idx // tail_n, n_full_bands * BAND_N + (tail_idx % tail_n)


@gluon.jit
def apply_block_schedule(
    block_id: gl.tensor,
    grid_m: gl.tensor,
    GRID_N: gl.constexpr,
    slice_offsets: gl.tensor,
    block_schedule: gl.tensor,
    BAND_N: gl.constexpr,
) -> tuple[gl.tensor, gl.tensor, gl.tensor, gl.tensor]:
    schedule_pid_m, pid_n = banded_row_major(block_id, grid_m, GRID_N, BAND_N=BAND_N)

    slice_idx, pid_m = unpack_block_schedule(gl.load(block_schedule + schedule_pid_m))
    slice_offset = gl.load(slice_offsets + slice_idx)

    return pid_m, pid_n, slice_idx, slice_offset


@gluon.jit
def unswizzle_mx_scale(
    smem,
    SIZE_OUTER: gl.constexpr,
    SIZE_INNER: gl.constexpr,
    MXFP_BLOCK_SIZE: gl.constexpr,
):
    rows: gl.constexpr = smem.shape[1]
    cols: gl.constexpr = smem.shape[2] * smem.shape[3] * smem.shape[4]
    tiles: gl.constexpr = cols // (SIZE_OUTER * SIZE_INNER)
    smem = smem.reshape((rows, tiles, MXFP_BLOCK_SIZE, SIZE_OUTER // MXFP_BLOCK_SIZE, SIZE_INNER))
    smem = smem.permute((0, 3, 2, 1, 4))
    return smem.reshape((rows * SIZE_OUTER, cols // SIZE_OUTER))


@gluon.jit
def alloc_barrier_ring(num_bufs: gl.constexpr, two_ctas: gl.constexpr = False):
    bars = mbarrier.allocate_mbarrier(batch=num_bufs, two_ctas=two_ctas)
    for i in gl.static_range(num_bufs):
        mbarrier.init(bars.index(i), count=1)
    return bars


@gluon.jit
def alloc_ring_barriers(
    num_bufs: gl.constexpr,
    producer_two_ctas: gl.constexpr = False,
    consumer_two_ctas: gl.constexpr = False,
):
    return (
        alloc_barrier_ring(num_bufs, two_ctas=producer_two_ctas),
        alloc_barrier_ring(num_bufs, two_ctas=consumer_two_ctas),
    )


@gluon.jit
def pack_e4m3x2(values):
    return gl.inline_asm_elementwise(
        """
        {
            .reg .f32 lane<2>;
            mov.b64 {lane0, lane1}, $1;
            cvt.rn.satfinite.e4m3x2.f32 $0, lane1, lane0;
        }
        """,
        "=h,l",
        [values.value],
        dtype=gl.int16,
        is_pure=True,
        pack=1,
    )


@gluon.jit
def pack_u16x2(x0, x1):
    return gl.inline_asm_elementwise(
        """
        mov.b32 $0, { $1, $2 };
        """,
        "=r,h,h",
        [x0, x1],
        dtype=gl.int32,
        is_pure=True,
        pack=1,
    )


@gluon.jit
def pack_fp8x4(values):
    lhs, rhs = gl.split(values.reshape((values.shape[0], values.shape[1] // 2, 2)))
    return pack_u16x2(lhs, rhs)


@gluon.jit
def _split_m(values):
    return gl.split(values.reshape((2, values.shape[0] // 2, values.shape[1])).permute((1, 2, 0)))


@gluon.jit
def _split_m_float2(values):
    lhs, rhs = _split_m(values.value)
    return float2.Float2Tensor(lhs), float2.Float2Tensor(rhs)


@gluon.jit
def split_m_subtiles(values, subtile_factor: gl.constexpr):
    # For epilogue subtiling.
    subtiles = (values, )
    for split_level in gl.static_range(5):
        if (1 << split_level) < subtile_factor:
            next_subtiles = ()
            for subtile_idx in gl.static_range(1 << split_level):
                lhs, rhs = _split_m_float2(subtiles[subtile_idx])
                next_subtiles += (lhs, rhs)
            subtiles = next_subtiles
    return subtiles


@gluon.aggregate
class PartitionArgs:
    x_desc: tma.tensor_descriptor
    w_desc: tma.tensor_descriptor
    scale_desc: tma.tensor_descriptor
    out_desc: tma.tensor_descriptor
    x_scale_ptr: gl.tensor | gl.constexpr
    w_scale_ptr: gl.tensor | gl.constexpr
    out_scale_ptr: gl.tensor

    out_ptr: gl.tensor
    bias_ptr: gl.tensor
    bias_stride: gl.tensor
    gather_indx_ptr: gl.tensor
    x_slice_sizes: gl.tensor
    x_slice_offs: gl.tensor
    x_block_schedule: gl.tensor

    x_bufs: gl.shared_memory_descriptor
    x_empty_bars: gl.shared_memory_descriptor
    x_ready_bars: gl.shared_memory_descriptor
    x_num_bufs: gl.constexpr

    w_bufs: gl.shared_memory_descriptor
    w_scale_bufs: gl.shared_memory_descriptor
    w_empty_bars: gl.shared_memory_descriptor
    w_ready_bars: gl.shared_memory_descriptor
    w_num_bufs: gl.constexpr

    x_scale_tmem: blackwell.tensor_memory_descriptor
    w_scale_tmem: blackwell.tensor_memory_descriptor
    acc_bufs: blackwell.tensor_memory_descriptor
    acc_empty_bars: gl.shared_memory_descriptor
    acc_ready_bars: gl.shared_memory_descriptor
    acc_num_bufs: gl.constexpr

    grid_m: gl.tensor
    GRID_N: gl.constexpr
    K_TILES: gl.constexpr
    SCALE_FLAT_N: gl.constexpr
    SCALE_BLOCK_N_DIV: gl.constexpr
    num_blocks: gl.tensor

    NUM_SMS: gl.constexpr
    USE_2CTA: gl.constexpr
    BLOCK_M_PER_CTA: gl.constexpr
    BLOCK_M: gl.constexpr
    BLOCK_N: gl.constexpr
    BLOCK_K: gl.constexpr
    SCALE_SIZE_OUTER: gl.constexpr
    SCALE_SIZE_INNER: gl.constexpr
    MXFP_BLOCK_SIZE: gl.constexpr

    SWIGLU_ALPHA: gl.constexpr
    SWIGLU_LIMIT: gl.constexpr
    REDUCTION_N: gl.constexpr
    FLEXPOINT_SATURATE_INF: gl.constexpr

    SWIGLU_SUBTILE_FACTOR: gl.constexpr
    BAND_N: gl.constexpr
    X_GATHER_MULTICAST: gl.constexpr
    W_SCALE_MULTICAST: gl.constexpr
    FORCE_EPILOGUE_WARPS_N1: gl.constexpr

    @gluon.jit
    def apply_block_schedule(self, block_id: gl.tensor) -> tuple[gl.tensor, gl.tensor, gl.tensor, gl.tensor]:
        return apply_block_schedule(
            block_id=block_id,
            grid_m=self.grid_m,
            GRID_N=self.GRID_N,
            slice_offsets=self.x_slice_offs,
            block_schedule=self.x_block_schedule,
            BAND_N=self.BAND_N,
        )


@gluon.jit
def load_activations(p: PartitionArgs):
    local_cga_layout: gl.constexpr = ((0, 1), ) if p.USE_2CTA else ()
    offs_layout: gl.constexpr = gl.SliceLayout(
        dim=0,
        parent=gl.BlockedLayout([1, 4], [32, 1], [1, gl.num_warps()], [1, 0], cga_layout=local_cga_layout),
    )
    tile_x_bytes: gl.constexpr = p.x_desc.block_type.nbytes * (p.BLOCK_M_PER_CTA if p.USE_2CTA else p.BLOCK_M)

    idx = 0
    phase = 1
    issued = 0

    for block_id in range(gl.program_id(0), p.num_blocks, p.NUM_SMS):
        pid_m, _, slice_idx, slice_offset = p.apply_block_schedule(block_id)
        off_m = pid_m * p.BLOCK_M
        shape_m = gl.load(p.x_slice_sizes + slice_idx)
        offs_m = off_m + gl.arange(0, p.BLOCK_M, layout=offs_layout)
        mask_m = offs_m < shape_m
        offs_x_m = gl.load(
            p.gather_indx_ptr + slice_offset + offs_m,
            mask=mask_m,
            other=p.x_desc.shape[0],
        )

        for ki in range(p.K_TILES):
            off_k_x = ki * p.BLOCK_K

            empty_bar = p.x_empty_bars.index(idx)
            ready_bar = p.x_ready_bars.index(idx)
            x_buf = p.x_bufs.index(idx)

            mbarrier.wait(empty_bar, phase, pred=issued >= p.x_num_bufs)
            mbarrier.expect(ready_bar, tile_x_bytes)
            tma.async_gather(
                p.x_desc,
                offs_x_m,
                off_k_x,
                ready_bar,
                x_buf,
                multicast=p.USE_2CTA and p.X_GATHER_MULTICAST,
            )

            idx, phase = advance(idx, phase, p.x_num_bufs)
            issued += 1


@gluon.jit
def load_weights(p: PartitionArgs):
    tile_w_bytes: gl.constexpr = p.w_desc.nbytes_per_cta
    tile_scale_bytes: gl.constexpr = p.scale_desc.nbytes_per_cta
    bytes_per_stage: gl.constexpr = tile_w_bytes + tile_scale_bytes

    idx = 0
    phase = 1
    issued = 0

    for block_id in range(gl.program_id(0), p.num_blocks, p.NUM_SMS):
        _, pid_n, slice_idx, _ = p.apply_block_schedule(block_id)

        scale_idx = slice_idx * p.SCALE_FLAT_N + pid_n * p.SCALE_BLOCK_N_DIV
        for ki in range(p.K_TILES):
            off_k_scale = ki * p.BLOCK_K // (p.MXFP_BLOCK_SIZE * p.SCALE_SIZE_INNER)

            w_empty_bar = p.w_empty_bars.index(idx)
            w_ready_bar = p.w_ready_bars.index(idx)
            w_buf = p.w_bufs.index(idx)
            scale_buf = p.w_scale_bufs.index(idx)

            mbarrier.wait(w_empty_bar, phase, pred=issued >= p.w_num_bufs)
            mbarrier.expect(w_ready_bar, bytes_per_stage)
            tma.async_copy_global_to_shared(p.w_desc, [slice_idx, ki, pid_n, 0, 0], w_ready_bar, w_buf)
            tma.async_copy_global_to_shared(
                p.scale_desc,
                [0, scale_idx, off_k_scale, 0, 0],
                w_ready_bar,
                scale_buf,
                multicast=p.USE_2CTA and p.W_SCALE_MULTICAST,
            )

            idx, phase = advance(idx, phase, p.w_num_bufs)
            issued += 1


@gluon.jit
def mma_partition(p: PartitionArgs):
    x_idx = 0
    x_phase = 0
    w_idx = 0
    w_phase = 0
    mma_idx = 0
    mma_phase = 1

    for block_id in range(gl.program_id(0), p.num_blocks, p.NUM_SMS):
        acc_empty_bar = p.acc_empty_bars.index(mma_idx)
        acc_ready_bar = p.acc_ready_bars.index(mma_idx)
        acc_buf = p.acc_bufs.index(mma_idx)
        mbarrier.wait(acc_empty_bar, mma_phase)

        use_acc = False
        for _ in range(p.K_TILES):
            w_ready_bar = p.w_ready_bars.index(w_idx)
            w_empty_bar = p.w_empty_bars.index(w_idx)
            w_buf = p.w_bufs.index(w_idx)
            scale_buf = p.w_scale_bufs.index(w_idx)
            mbarrier.wait(w_ready_bar, w_phase)

            blackwell.tcgen05_copy(
                unswizzle_mx_scale(scale_buf, p.SCALE_SIZE_OUTER, p.SCALE_SIZE_INNER, p.MXFP_BLOCK_SIZE),
                p.w_scale_tmem,
            )

            x_ready_bar = p.x_ready_bars.index(x_idx)
            x_empty_bar = p.x_empty_bars.index(x_idx)
            x_buf = p.x_bufs.index(x_idx)
            mbarrier.wait(x_ready_bar, x_phase)

            blackwell.tcgen05_mma_scaled(
                w_buf.reshape((p.BLOCK_N, p.BLOCK_K // 2)),
                x_buf.permute((1, 0)),
                acc_buf,
                p.w_scale_tmem,
                p.x_scale_tmem,
                a_type="e2m1",
                b_type="e4m3",
                use_acc=use_acc,
            )
            blackwell.tcgen05_commit(x_empty_bar)
            blackwell.tcgen05_commit(w_empty_bar)

            x_idx, x_phase = advance(x_idx, x_phase, p.x_num_bufs)
            w_idx, w_phase = advance(w_idx, w_phase, p.w_num_bufs)
            use_acc = True

        blackwell.tcgen05_commit(acc_ready_bar)
        mma_idx, mma_phase = advance(mma_idx, mma_phase, p.acc_num_bufs)


@gluon.jit
def store_packed_out(
    p: PartitionArgs,
    packed_out,
    off_m,
    out_off_n_packed,
    shape_m,
    slice_offset,
):
    values = pack_fp8x4(packed_out)
    layout: gl.constexpr = values.type.layout
    offs_m = off_m + gl.arange(0, values.shape[0], layout=gl.SliceLayout(1, layout))
    offs_n = out_off_n_packed + gl.arange(0, values.shape[1], layout=gl.SliceLayout(0, layout))
    mask_m = gl.expand_dims(offs_m < shape_m, 1)
    mask_n = gl.expand_dims(offs_n < (p.out_desc.shape[1] + 3) // 4, 0)
    mask = mask_m & mask_n
    ptrs = p.out_ptr.cast(gl.pointer_type(gl.int32), bitcast=True)
    ptrs = ptrs + gl.expand_dims(slice_offset + offs_m, 1) * (p.out_desc.strides[0] // 4)
    ptrs = ptrs + gl.expand_dims(offs_n, 0) * p.out_desc.strides[1]
    gl.store(ptrs, values, mask=mask)


@gluon.jit
def _swiglu_step1(acc_packed, limit):
    gelu, linear = float2.unpack2(acc_packed)
    gelu = gl.minimum(gelu.to(gl.float32), limit)
    linear = gl.minimum(gl.maximum(linear.to(gl.float32), -limit), limit)
    return gelu, linear


@gluon.jit
def _swiglu_step2(gelu, linear, alpha):
    den = 1.0 + libdevice.exp(-alpha * gelu)
    activated = gelu / den
    activated_packed = float2.pack(activated, axis=1)
    linear_packed = float2.pack(linear, axis=1)
    return float2.fma(activated_packed, linear_packed, activated_packed)


@gluon.jit
def pack_fp8_out_fragment(out_packed, out_recip):
    scaled_out_packed = out_packed * float2.full_like(out_packed, out_recip)
    return pack_e4m3x2(scaled_out_packed)


@gluon.jit
def get_store_layout(p: PartitionArgs):
    frag_rows: gl.constexpr = p.BLOCK_M // p.SWIGLU_SUBTILE_FACTOR
    local_cga_layout: gl.constexpr = ((0, 1), ) if p.USE_2CTA else ()
    return gl.BlockedLayout(
        [frag_rows // gl.num_warps(), 2],
        [1, 32],
        [gl.num_warps(), 1],
        [1, 0],
        cga_layout=local_cga_layout,
    )


@gluon.jit
def epilogue_direct_store(
    p: PartitionArgs,
    acc_packed,
    out_recip,
    off_m,
    out_off_n_packed,
    shape_m,
    slice_offset,
    store_layout: gl.constexpr,
):
    frag_rows: gl.constexpr = p.BLOCK_M // p.SWIGLU_SUBTILE_FACTOR
    acc_packed_subtiles = split_m_subtiles(acc_packed, p.SWIGLU_SUBTILE_FACTOR)
    for frag_idx in gl.static_range(p.SWIGLU_SUBTILE_FACTOR):
        gelu, linear = _swiglu_step1(acc_packed_subtiles[frag_idx], p.SWIGLU_LIMIT)
        out_packed = _swiglu_step2(gelu, linear, p.SWIGLU_ALPHA)
        packed_fp8 = gl.convert_layout(pack_fp8_out_fragment(out_packed, out_recip), store_layout)
        store_packed_out(
            p,
            packed_fp8,
            off_m + frag_idx * frag_rows,
            out_off_n_packed,
            shape_m,
            slice_offset,
        )


@gluon.jit
def apply_bias_and_scale(
    p: PartitionArgs,
    idx,
    phase,
    pid_n,
    slice_idx,
    split_layout: gl.constexpr,
    bias_layout: gl.constexpr,
    acc_scale,
):
    off_n = pid_n * p.BLOCK_N
    acc_empty_bar = p.acc_empty_bars.index(idx)
    acc_ready_bar = p.acc_ready_bars.index(idx)
    acc_buf = p.acc_bufs.index(idx)

    offs_bias_n = off_n + gl.arange(0, p.BLOCK_N, layout=bias_layout)
    bias = gl.convert_layout(
        gl.expand_dims(gl.load(p.bias_ptr + slice_idx * p.bias_stride + offs_bias_n), axis=0),
        split_layout,
    )
    mbarrier.wait(acc_ready_bar, phase)
    acc_regs = acc_buf.load().permute((1, 0))
    mbarrier.arrive(acc_empty_bar)
    idx, phase = advance(idx, phase, p.acc_num_bufs)
    acc = gl.convert_layout(acc_regs, split_layout)
    acc_packed = float2.pack(acc, axis=1)
    bias_packed = float2.pack(bias, axis=1)
    bias_packed = float2.Float2Tensor(gl.convert_layout(bias_packed.value, acc_packed.value.type.layout))
    acc_packed = float2.fma(acc_packed, float2.full_like(acc_packed, acc_scale), bias_packed)
    return idx, phase, acc_packed


@gluon.jit
def epilogue_partition(p: PartitionArgs):
    idx = 0
    phase = 0

    x_scale = 1.0 if p.x_scale_ptr is None else gl.load(p.x_scale_ptr)
    w_scale = 1.0 if p.w_scale_ptr is None else gl.load(p.w_scale_ptr)
    acc_scale = x_scale * w_scale
    out_recip = 1.0 / gl.load(p.out_scale_ptr)

    num_warps: gl.constexpr = gl.num_warps()
    warps_n: gl.constexpr = 1 if p.FORCE_EPILOGUE_WARPS_N1 else (2 if num_warps >= 8 and p.BLOCK_N >= 256 else 1)
    split_cga_layout: gl.constexpr = ((0, 1), ) if p.USE_2CTA else ()
    split_layout: gl.constexpr = gl.BlockedLayout(
        [1, 4],
        [1, 32],
        [num_warps // warps_n, warps_n],
        [1, 0],
        cga_layout=split_cga_layout,
    )
    bias_layout: gl.constexpr = gl.SliceLayout(0, split_layout)
    store_layout: gl.constexpr = get_store_layout(p)

    for block_id in range(gl.program_id(0), p.num_blocks, p.NUM_SMS):
        pid_m, pid_n, slice_idx, slice_offset = p.apply_block_schedule(block_id)
        off_m = pid_m * p.BLOCK_M
        shape_m = gl.load(p.x_slice_sizes + slice_idx)
        out_off_n_packed = pid_n * (p.BLOCK_N // p.REDUCTION_N // 4)
        idx, phase, acc_packed = apply_bias_and_scale(
            p,
            idx,
            phase,
            pid_n,
            slice_idx,
            split_layout,
            bias_layout,
            acc_scale,
        )
        epilogue_direct_store(
            p,
            acc_packed,
            out_recip,
            off_m,
            out_off_n_packed,
            shape_m,
            slice_offset,
            store_layout,
        )


@gluon.jit
def ws_matmul_kernel(
    x_desc: tma.tensor_descriptor,
    w_desc: tma.tensor_descriptor,
    scale_desc: tma.tensor_descriptor,
    out_desc: tma.tensor_descriptor,
    out_ptr: gl.tensor,
    #
    bias_ptr: gl.tensor,
    bias_stride: gl.tensor,
    #
    gather_indx_ptr: gl.tensor,
    #
    x_slice_sizes: gl.tensor,
    x_slice_offs: gl.tensor,
    x_block_offs: gl.tensor,
    x_block_schedule: gl.tensor,
    #
    x_scale_ptr: gl.tensor,
    w_scale_ptr: gl.tensor,
    out_scale_ptr: gl.tensor,
    #
    M: gl.constexpr,
    N: gl.constexpr,
    K: gl.constexpr,
    NUM_SLICES: gl.constexpr,
    #
    SWIGLU_ALPHA: gl.constexpr,
    SWIGLU_LIMIT: gl.constexpr,
    REDUCTION_N: gl.constexpr,
    #
    FLEXPOINT_SATURATE_INF: gl.constexpr,
    #
    BLOCK_M: gl.constexpr,
    BLOCK_N: gl.constexpr,
    BLOCK_K: gl.constexpr,
    NUM_SMS: gl.constexpr,
    X_NUM_BUFS: gl.constexpr,
    W_NUM_BUFS: gl.constexpr,
    ACC_NUM_BUFS: gl.constexpr,
    LOAD_ACTIVATION_WARPS: gl.constexpr,
    LOAD_WEIGHT_WARPS: gl.constexpr,
    MMA_WARPS: gl.constexpr,
    LOAD_ACTIVATION_REGS: gl.constexpr,
    LOAD_WEIGHT_REGS: gl.constexpr,
    MMA_REGS: gl.constexpr,
    SWIGLU_SUBTILE_FACTOR: gl.constexpr,
    BAND_N: gl.constexpr,
    X_GATHER_MULTICAST: gl.constexpr,
    W_SCALE_MULTICAST: gl.constexpr,
    FORCE_EPILOGUE_WARPS_N1: gl.constexpr,
    SCALE_SIZE_OUTER: gl.constexpr,
    SCALE_SIZE_INNER: gl.constexpr,
    MXFP_BLOCK_SIZE: gl.constexpr,
):
    use_2cta: gl.constexpr = gl.num_ctas() > 1
    gl.static_assert(gl.num_ctas() == 1 or gl.num_ctas() == 2, "kernel supports at most 2 CTAs")
    gl.static_assert(not use_2cta or BLOCK_N >= 256, "2CTA path requires at least 128 columns per CTA")

    grid_m = gl.load(x_block_offs + NUM_SLICES)
    grid_n: gl.constexpr = triton.cdiv(N, BLOCK_N)
    k_tiles: gl.constexpr = triton.cdiv(K, BLOCK_K)
    scale_flat_n: gl.constexpr = N // SCALE_SIZE_OUTER
    scale_block_n_div: gl.constexpr = BLOCK_N // SCALE_SIZE_OUTER
    num_blocks = grid_m * grid_n

    scale_k: gl.constexpr = BLOCK_K // MXFP_BLOCK_SIZE
    block_m_per_cta: gl.constexpr = BLOCK_M // gl.num_ctas()
    x_scale_layout: gl.constexpr = blackwell.TensorMemoryScalesLayout(cga_layout=((0, 0), ) if use_2cta else (), )
    w_scale_layout: gl.constexpr = blackwell.TensorMemoryScalesLayout(cga_layout=((1, 0), ) if use_2cta else (), )
    mma_block_col: gl.constexpr = min(128, BLOCK_N // gl.num_ctas())
    acc_layout: gl.constexpr = blackwell.TensorMemoryLayout(
        [mma_block_col, BLOCK_M],
        col_stride=1,
        cga_layout=((1, 0), ) if use_2cta else (),
        two_ctas=use_2cta,
    )

    x_num_bufs: gl.constexpr = X_NUM_BUFS
    x_bufs = gl.allocate_shared_memory(
        x_desc.dtype,
        [x_num_bufs, BLOCK_M, x_desc.block_type.shape[1]],
        x_desc.layout,
    )
    x_empty_bars, x_ready_bars = alloc_ring_barriers(x_num_bufs, consumer_two_ctas=use_2cta)

    w_num_bufs: gl.constexpr = W_NUM_BUFS
    w_bufs = gl.allocate_shared_memory(
        w_desc.dtype,
        [w_num_bufs] + w_desc.block_type.shape,
        w_desc.layout,
    )
    w_scale_bufs = gl.allocate_shared_memory(
        scale_desc.dtype,
        [w_num_bufs] + scale_desc.block_type.shape,
        scale_desc.layout,
    )
    w_empty_bars, w_ready_bars = alloc_ring_barriers(w_num_bufs, consumer_two_ctas=use_2cta)

    x_scale_tmem = blackwell.allocate_tensor_memory(gl.uint8, [BLOCK_M, scale_k], x_scale_layout)
    w_scale_tmem = blackwell.allocate_tensor_memory(gl.uint8, [BLOCK_N, scale_k], w_scale_layout)

    acc_num_bufs: gl.constexpr = ACC_NUM_BUFS
    acc_tmem = blackwell.allocate_tensor_memory(
        gl.float32,
        [acc_num_bufs, BLOCK_N, BLOCK_M],
        acc_layout,
    )
    acc_empty_bars, acc_ready_bars = alloc_ring_barriers(acc_num_bufs, producer_two_ctas=use_2cta)

    x_scale_tmem.store(gl.full((BLOCK_M, scale_k), 127, dtype=gl.uint8, layout=x_scale_tmem.get_reg_layout()))

    p = PartitionArgs(
        x_desc=x_desc,
        w_desc=w_desc,
        scale_desc=scale_desc,
        out_desc=out_desc,
        x_scale_ptr=x_scale_ptr,
        w_scale_ptr=w_scale_ptr,
        out_scale_ptr=out_scale_ptr,
        #
        out_ptr=out_ptr,
        bias_ptr=bias_ptr,
        bias_stride=bias_stride,
        gather_indx_ptr=gather_indx_ptr,
        x_slice_sizes=x_slice_sizes,
        x_slice_offs=x_slice_offs,
        x_block_schedule=x_block_schedule,
        #
        x_bufs=x_bufs,
        x_empty_bars=x_empty_bars,
        x_ready_bars=x_ready_bars,
        x_num_bufs=x_num_bufs,
        #
        w_bufs=w_bufs,
        w_scale_bufs=w_scale_bufs,
        w_empty_bars=w_empty_bars,
        w_ready_bars=w_ready_bars,
        w_num_bufs=w_num_bufs,
        #
        x_scale_tmem=x_scale_tmem,
        w_scale_tmem=w_scale_tmem,
        acc_bufs=acc_tmem,
        acc_empty_bars=acc_empty_bars,
        acc_ready_bars=acc_ready_bars,
        acc_num_bufs=acc_num_bufs,
        #
        grid_m=grid_m,
        GRID_N=grid_n,
        K_TILES=k_tiles,
        SCALE_FLAT_N=scale_flat_n,
        SCALE_BLOCK_N_DIV=scale_block_n_div,
        num_blocks=num_blocks,
        #
        NUM_SMS=NUM_SMS,
        USE_2CTA=use_2cta,
        BLOCK_M_PER_CTA=block_m_per_cta,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        SCALE_SIZE_OUTER=SCALE_SIZE_OUTER,
        SCALE_SIZE_INNER=SCALE_SIZE_INNER,
        MXFP_BLOCK_SIZE=MXFP_BLOCK_SIZE,
        #
        SWIGLU_ALPHA=SWIGLU_ALPHA,
        SWIGLU_LIMIT=SWIGLU_LIMIT,
        REDUCTION_N=REDUCTION_N,
        FLEXPOINT_SATURATE_INF=FLEXPOINT_SATURATE_INF,
        #
        SWIGLU_SUBTILE_FACTOR=SWIGLU_SUBTILE_FACTOR,
        BAND_N=BAND_N,
        X_GATHER_MULTICAST=X_GATHER_MULTICAST,
        W_SCALE_MULTICAST=W_SCALE_MULTICAST,
        FORCE_EPILOGUE_WARPS_N1=FORCE_EPILOGUE_WARPS_N1,
    )

    gl.warp_specialize(
        [
            (epilogue_partition, (p, )),
            (load_activations, (p, )),
            (load_weights, (p, )),
            (mma_partition, (p, )),
        ],
        [LOAD_ACTIVATION_WARPS, LOAD_WEIGHT_WARPS, MMA_WARPS],
        [LOAD_ACTIVATION_REGS, LOAD_WEIGHT_REGS, MMA_REGS],
    )


# ===-----------------------------------------------------------------------===#
# Host Code
# ===-----------------------------------------------------------------------===#


def make_tensor_descriptor(
        t: torch.Tensor | Tensor,
        block_shape: tuple[int, ...],
        *,
        layout_block_shape: tuple[int, ...] | None = None,
        cga_layout: tuple[tuple[int, ...], ...] = (),
):
    from triton.experimental.gluon.nvidia.hopper import TensorDescriptor

    ptr = t if isinstance(t, torch.Tensor) else t.storage.data
    shape = list(ptr.shape)
    strides = list(ptr.stride())
    desc_block_shape = list(block_shape)
    layout_shape = list(layout_block_shape or block_shape)

    if isinstance(t, Tensor) and t.dtype == FP4:
        assert isinstance(t.storage.layout, BlackwellMX4ValueShuffledLayout)
        assert layout_block_shape is None
        desc_block_shape = t.storage.layout.swizzle_block_shape(desc_block_shape)
        desc_block_shape[strides.index(1)] //= 2
        layout_shape = desc_block_shape

    rank = len(layout_shape)
    if t.dtype == FP4:
        assert rank == 5
        layout = gl.NVMMASharedLayout(
            swizzle_byte_width=128,
            element_bitwidth=8,
            rank=rank,
            fp4_padded=True,
            cga_layout=cga_layout,
        )
    elif t.dtype == UINT8:
        assert rank == 5
        layout = gl.NVMMASharedLayout(
            swizzle_byte_width=0,
            element_bitwidth=8,
            rank=rank,
            cga_layout=cga_layout,
        )
    elif t.dtype == torch.float32:
        assert rank == 2
        layout = gl.NVMMASharedLayout.get_default_for(
            layout_shape,
            torch.float32,
            cga_layout=cga_layout,
        )
    else:
        assert t.dtype == torch.float8_e4m3fn
        layout = gl.NVMMASharedLayout(
            swizzle_byte_width=layout_shape[-1],
            element_bitwidth=8,
            rank=rank,
            cga_layout=cga_layout,
        )
    return TensorDescriptor(ptr, shape, strides, desc_block_shape, layout)


@dataclass(frozen=True, slots=True)
class KernelConfig:
    BLOCK_M: int = 128
    BLOCK_N: int = 256
    BLOCK_K: int = 128

    NUM_CTAS: int = 1
    X_NUM_BUFS: int = 5
    W_NUM_BUFS: int = 4
    ACC_NUM_BUFS: int = 1

    NUM_WARPS: int = 8
    LOAD_ACTIVATION_WARPS: int = 4
    LOAD_WEIGHT_WARPS: int = 1
    MMA_WARPS: int = 1

    SWIGLU_SUBTILE_FACTOR: int = 8
    BAND_N: int = 20
    X_GATHER_MULTICAST: bool = True
    W_SCALE_MULTICAST: bool = True
    FORCE_EPILOGUE_WARPS_N1: bool = False

    LOAD_ACTIVATION_REGS: int = 112
    LOAD_WEIGHT_REGS: int = 48
    MMA_REGS: int = 48
    MAXNREG: int = None
    OCCUPANCY: int = 1

    MXFP_BLOCK_SIZE: int = 32
    SCALE_SIZE_OUTER: int = 128
    SCALE_SIZE_INNER: int = 4

    def get_x_tile_smem(self) -> int:
        return self.BLOCK_M * self.BLOCK_K

    def get_w_tile_smem(self) -> int:
        return self.BLOCK_N * self.BLOCK_K

    def get_w_mx_tile_smem(self) -> int:
        return self.get_w_tile_smem() // self.MXFP_BLOCK_SIZE

    def get_c_tile_smem(self, reduction_n: int) -> int:
        return (self.BLOCK_M // self.SWIGLU_SUBTILE_FACTOR) * (self.BLOCK_N // reduction_n)


def _select_base_config(slice_size: int) -> KernelConfig:
    if slice_size <= 14:
        block_m = 16
    elif slice_size <= 32:
        block_m = 32
    elif slice_size <= 64:
        block_m = 64
    else:
        block_m = 128

    if slice_size <= 64:
        x_num_bufs, w_num_bufs = {
            16: (10, 5),
            32: (5, 5),
            64: (4, 4),
        }[block_m]
        return KernelConfig(
            BLOCK_M=block_m,
            BLOCK_N=128,
            X_NUM_BUFS=x_num_bufs,
            W_NUM_BUFS=w_num_bufs,
            SWIGLU_SUBTILE_FACTOR=min(8, block_m // 8),
            OCCUPANCY=2,
            MAXNREG=64,
            LOAD_ACTIVATION_REGS=48,
            LOAD_WEIGHT_REGS=32,
            MMA_REGS=32,
        )

    return KernelConfig(
        BLOCK_M=block_m,
        SWIGLU_SUBTILE_FACTOR=min(8, block_m // 8),
    )


def select_kernel_config(slice_size: int) -> KernelConfig:
    p = _select_base_config(slice_size)

    if p.BLOCK_M == 32 and p.BLOCK_N == 128 and slice_size in (16, 20, 24, 32):
        p = replace(
            p,
            BLOCK_N=256,
            NUM_CTAS=2,
            NUM_WARPS=4,
            W_NUM_BUFS=5,
            SWIGLU_SUBTILE_FACTOR=1,
            LOAD_ACTIVATION_WARPS=2,
            LOAD_WEIGHT_WARPS=1,
            MMA_WARPS=1,
            LOAD_ACTIVATION_REGS=40,
            LOAD_WEIGHT_REGS=32,
            MMA_REGS=32,
            MAXNREG=52,
            BAND_N=32,
            FORCE_EPILOGUE_WARPS_N1=True,
        )
        if slice_size == 16:
            p = replace(p, X_NUM_BUFS=5)
        elif slice_size in (20, 24):
            p = replace(p, X_NUM_BUFS=6)
        else:
            p = replace(p, X_NUM_BUFS=6, X_GATHER_MULTICAST=False, W_SCALE_MULTICAST=False)
    elif 36 <= slice_size <= 72:
        p = replace(
            p,
            BLOCK_M=64,
            BLOCK_N=256,
            NUM_CTAS=2,
            OCCUPANCY=2,
            X_NUM_BUFS=6,
            W_NUM_BUFS=5,
            SWIGLU_SUBTILE_FACTOR=4,
            LOAD_ACTIVATION_REGS=64,
            LOAD_WEIGHT_REGS=48,
            MMA_REGS=48,
            MAXNREG=64,
        )
    elif p.BLOCK_M == 128 and p.BLOCK_N == 256 and slice_size >= 80:
        p = replace(p, BLOCK_N=512, NUM_CTAS=2, W_NUM_BUFS=5)

    if p.BLOCK_M == 32 and p.BLOCK_N == 256 and p.NUM_CTAS == 2 and slice_size <= 32:
        return replace(p, BAND_N=32)
    if p.BLOCK_M == 64 and p.BLOCK_N == 256 and p.NUM_CTAS == 2 and 36 <= slice_size <= 72:
        return replace(p, BAND_N=26)
    if slice_size < 32:
        return replace(p, BAND_N=22)
    if slice_size < 416:
        return replace(p, BAND_N=18)
    return replace(p, BAND_N=26)


def matmul(
    a: torch.Tensor,
    b: torch.Tensor | Tensor,
    bias: torch.Tensor,
    a_ragged_metadata: RaggedTensorMetadata,
    gather_indx: torch.Tensor,
    precision_config: PrecisionConfig,
    c: torch.Tensor,
    fused_activation: FusedActivation,
    p: KernelConfig | None = None,
):
    specs = fused_activation.specs
    assert specs.name == "swiglu"
    reduction_n = specs.reduction_n
    swiglu_alpha, swiglu_limit = fused_activation.fn_args

    b_mx_scales = precision_config.b_mx_scale

    out_dtype = precision_config.out_dtype
    assert out_dtype is not None

    assert c.ndim == 2

    flex_ctx = precision_config.flex_ctx

    assert a.ndim == 2
    _, k = a.shape
    _, _, n = b.shape
    m = gather_indx.shape[0]

    p = p or select_kernel_config(a_ragged_metadata.expected_slice_size)
    assert isinstance(b, Tensor)
    assert isinstance(b.storage.layout, BlackwellMX4ValueShuffledLayout)
    assert b.storage.layout.block_k == p.BLOCK_K
    assert b.storage.layout.block_n == p.BLOCK_N
    x_block_offs = a_ragged_metadata.block_offs(p.BLOCK_M)
    x_block_schedule = a_ragged_metadata.block_schedule(p.BLOCK_M)
    expected_grid_m = a_ragged_metadata.n_blocks(a_ragged_metadata.n_slices, m, p.BLOCK_M)
    grid_n = triton.cdiv(n, p.BLOCK_N)
    sms = torch.cuda.get_device_properties(bias.device).multi_processor_count
    sms *= p.OCCUPANCY
    launch_grid = max(1, min(max(1, sms // p.NUM_CTAS), expected_grid_m * grid_n))
    grid = (launch_grid, )
    if p.NUM_CTAS == 1:
        acc_cga_layout = ()
    elif p.NUM_CTAS == 2:
        acc_cga_layout = ((1, 0), )
    else:
        raise ValueError(f"unsupported CTA count: {p.NUM_CTAS}")

    x_desc = make_tensor_descriptor(
        a,
        (1, p.BLOCK_K),
        layout_block_shape=(p.BLOCK_M, p.BLOCK_K),
        cga_layout=tuple((basis[0], 0) for basis in acc_cga_layout),
    )
    w_desc = make_tensor_descriptor(
        b,
        (1, p.BLOCK_K, p.BLOCK_N),
        # Sharded weight tiles use the physical [1, 1, 1, N, K/2] MX4 shuffled block layout.
        cga_layout=tuple((0, 0, 0, basis[0], 0) for basis in acc_cga_layout),
    )
    scale_desc = make_tensor_descriptor(
        b_mx_scales,
        (
            1,
            p.BLOCK_N // p.SCALE_SIZE_OUTER,
            p.BLOCK_K // p.MXFP_BLOCK_SIZE // p.SCALE_SIZE_INNER,
            2,
            256,
        ),
        # Weight scale tiles use the physical [1, N//128, K//(32*4), 2, 256] layout.
        cga_layout=tuple((0, basis[0], 0, 0, 0) for basis in acc_cga_layout),
    )
    # The output descriptor is only used for shape/stride metadata during
    # direct stores, so cap the layout width to a legal FP8 swizzle size.
    out_desc = make_tensor_descriptor(c, (p.BLOCK_M, min(p.BLOCK_N // reduction_n, 128)))

    ws_matmul_kernel[grid](
        x_desc=x_desc,
        w_desc=w_desc,
        scale_desc=scale_desc,
        out_desc=out_desc,
        out_ptr=c,
        #
        bias_ptr=bias,
        bias_stride=bias.stride(0),
        #
        gather_indx_ptr=gather_indx,
        #
        x_slice_sizes=a_ragged_metadata.slice_sizes,
        x_slice_offs=a_ragged_metadata.slice_offs,
        x_block_offs=x_block_offs,
        x_block_schedule=x_block_schedule,
        #
        x_scale_ptr=flex_ctx.lhs_data.scale,
        w_scale_ptr=flex_ctx.rhs_data.scale,
        out_scale_ptr=flex_ctx.out_data.expected_scale,
        #
        M=m,
        N=n,
        K=k,
        NUM_SLICES=a_ragged_metadata.n_slices,
        #
        SWIGLU_ALPHA=swiglu_alpha,
        SWIGLU_LIMIT=swiglu_limit,
        REDUCTION_N=reduction_n,
        #
        FLEXPOINT_SATURATE_INF=precision_config.flexpoint_saturate_inf,
        #
        BLOCK_M=p.BLOCK_M,
        BLOCK_N=p.BLOCK_N,
        BLOCK_K=p.BLOCK_K,
        NUM_SMS=launch_grid,
        X_NUM_BUFS=p.X_NUM_BUFS,
        W_NUM_BUFS=p.W_NUM_BUFS,
        ACC_NUM_BUFS=p.ACC_NUM_BUFS,
        LOAD_ACTIVATION_WARPS=p.LOAD_ACTIVATION_WARPS,
        LOAD_WEIGHT_WARPS=p.LOAD_WEIGHT_WARPS,
        MMA_WARPS=p.MMA_WARPS,
        LOAD_ACTIVATION_REGS=p.LOAD_ACTIVATION_REGS,
        LOAD_WEIGHT_REGS=p.LOAD_WEIGHT_REGS,
        MMA_REGS=p.MMA_REGS,
        SWIGLU_SUBTILE_FACTOR=p.SWIGLU_SUBTILE_FACTOR,
        BAND_N=p.BAND_N,
        X_GATHER_MULTICAST=p.X_GATHER_MULTICAST,
        W_SCALE_MULTICAST=p.W_SCALE_MULTICAST,
        FORCE_EPILOGUE_WARPS_N1=p.FORCE_EPILOGUE_WARPS_N1,
        #
        SCALE_SIZE_OUTER=p.SCALE_SIZE_OUTER,
        SCALE_SIZE_INNER=p.SCALE_SIZE_INNER,
        MXFP_BLOCK_SIZE=p.MXFP_BLOCK_SIZE,
        #
        num_warps=p.NUM_WARPS,
        num_ctas=p.NUM_CTAS,
        maxnreg=p.MAXNREG,
    )

    return c


# ===-----------------------------------------------------------------------===#
# Benchmark and Testing Helpers
# ===-----------------------------------------------------------------------===#


@dataclass(frozen=True, slots=True)
class MLPConfig:
    name: str
    num_experts: int
    experts_per_token: int
    num_expert_shards: int
    hidden_size: int
    intermediate_size: int


def get_batch_sizes(c: MLPConfig) -> tuple[int, ...]:
    batch_per_expert = tuple(chain.from_iterable(range(2**(2 + k), 2**(3 + k), min(2**k, 32)) for k in range(8)))
    return tuple(batch_per_expert * c.num_experts // c.experts_per_token for batch_per_expert in batch_per_expert)


@dataclass(frozen=True, slots=True)
class PreparedCase:
    batch_size: int
    local_rank: int
    x: torch.Tensor
    w: Tensor
    w_scale: Tensor
    bias: torch.Tensor
    ragged_metadata: RaggedTensorMetadata
    gather_indx: torch.Tensor
    fused_activation: FusedActivation
    x_scale: torch.Tensor
    y_scale: torch.Tensor
    out_shape: tuple[int, int]
    out_dtype: torch.dtype


def alloc_randn(shape: tuple[int, ...], dtype: torch.dtype, device: str) -> torch.Tensor:
    if dtype.itemsize == 1:
        return alloc_rand(shape, device=device, dtype=dtype)
    return torch.randn(shape, device=device, dtype=dtype)


def alloc_randn_fp4(shape: tuple[int, ...], device: str, p: KernelConfig | None) -> tuple[Tensor, Tensor]:
    if p is not None:
        block_k, block_n, num_warps = p.BLOCK_K, p.BLOCK_N, p.NUM_WARPS
    else:
        block_k, block_n, num_warps = 128, 256, 8

    data = alloc_randn(shape, torch.bfloat16, device)
    data, scale = downcast_to_mxfp(data, FP4, axis=1)  # type: ignore[arg-type]
    data_layout = BlackwellMX4ValueShuffledLayout(block_k=block_k, block_n=block_n)
    scale_layout = make_default_matmul_mxfp4_w_scale_layout(mx_axis=1, num_warps=num_warps)
    data = convert_layout(wrap_torch_tensor(data, dtype=FP4), data_layout)
    scale = convert_layout(wrap_torch_tensor(scale), scale_layout)
    return data, scale


def make_prod_like_logits(
    batch_size: int,
    num_experts: int,
    experts_per_token: int,
    device: str,
    dtype: torch.dtype = torch.float16,
    *,
    zipf_alpha: float = 1.10,
    num_clusters: int = 16,
    cluster_boost: float = 1.25,
    gumbel_scale: float = 0.75,
    batch_hot_experts: int = 4,
    batch_hot_boost: float = 0.6,
) -> torch.Tensor:
    # Stable expert popularity: a few hot experts, long tail.
    ranks = torch.arange(1, num_experts + 1, device=device, dtype=torch.float32)
    ranked_probs = ranks.pow(-zipf_alpha)
    ranked_probs /= ranked_probs.sum()

    # Randomize which expert ids are hot so shard/id layout is not special.
    perm = torch.randperm(num_experts, device=device)
    expert_probs = torch.empty_like(ranked_probs)
    expert_probs[perm] = ranked_probs

    logits = expert_probs.clamp_min(1e-12).log()[None, :].expand(batch_size, -1).clone()

    # Token locality: each token belongs to a synthetic topic/cluster with preferred experts.
    cluster_size = min(num_experts, max(2 * experts_per_token, num_experts // 16))
    cluster_experts = torch.stack(
        [torch.multinomial(expert_probs, cluster_size, replacement=False) for _ in range(num_clusters)])
    token_cluster = torch.randint(num_clusters, (batch_size, ), device=device)
    rows = torch.arange(batch_size, device=device)[:, None]
    logits[rows, cluster_experts[token_cluster]] += cluster_boost

    # Batch burstiness: a few experts are hotter for this batch.
    if batch_hot_experts > 0:
        hot = torch.multinomial(expert_probs, batch_hot_experts, replacement=False)
        logits[:, hot] += batch_hot_boost

    # Gumbel noise makes top-k behave like weighted sampling without replacement.
    noise = -torch.empty_like(logits).exponential_().log()
    logits += gumbel_scale * noise

    return logits.to(dtype)


def init_routing_data(c: MLPConfig, batch_size: int, local_rank: int, device: str,
                      uniform_routing: bool) -> tuple[RaggedTensorMetadata, torch.Tensor]:
    expt_dist = make_expt_dict_uniform(c.num_expert_shards, c.num_experts)
    if uniform_routing:
        logits = torch.randn((batch_size, c.num_experts), dtype=torch.float16, device=device)
    else:
        logits = make_prod_like_logits(batch_size, c.num_experts, c.experts_per_token, device)
    sparse_logits = topk(logits, c.experts_per_token, apply_softmax=True)
    expt_hist = sparse_logits.mask_metadata.col_sum
    local_expts = expt_dist[local_rank]
    local_expts_hist = expt_hist[local_expts]
    ragged_metadata = make_ragged_tensor_metadata(local_expts_hist, batch_size * c.experts_per_token)
    ragged_metadata.expected_slice_size = batch_size * c.experts_per_token // c.num_experts
    combine_indx = sparse_logits.mask_metadata.col_sorted_indx
    gather_indx = torch.div(combine_indx, c.experts_per_token, rounding_mode="trunc")
    return ragged_metadata, gather_indx


def prepare_case(c: MLPConfig, batch_size: int, device: str, seed: int = 0, uniform_routing: bool = False,
                 reference: bool = False) -> PreparedCase:
    torch.manual_seed(seed)

    local_rank = int(torch.randint(0, c.num_expert_shards, size=()).item())
    k, n = c.hidden_size, c.intermediate_size
    n_expts_local = c.num_experts // c.num_expert_shards
    ragged_metadata, gather_indx = init_routing_data(c, batch_size, local_rank, device, uniform_routing)
    p = None if reference else select_kernel_config(ragged_metadata.expected_slice_size)
    x = alloc_randn((batch_size, k), dtype=torch.float8_e4m3fn, device=device)
    w, w_scale = alloc_randn_fp4((n_expts_local, k, n), device=device, p=p)
    bias = alloc_randn((n_expts_local, n), dtype=torch.float32, device=device)

    swiglu_alpha = float(torch.rand((), device=device).item()) / 5 + 1.0
    swiglu_limit = float(torch.rand((), device=device).item()) / 5 + 1.3
    fused_activation = FusedActivation(
        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"), reduction_n=2),
        (swiglu_alpha, swiglu_limit),
    )

    x_scale = (torch.rand((), device=device) + 0.5).reshape(1)
    y_scale = (torch.rand((), device=device) + 3.5).reshape(1)
    return PreparedCase(
        batch_size=batch_size,
        local_rank=local_rank,
        x=x,
        w=w,
        w_scale=w_scale,
        bias=bias,
        ragged_metadata=ragged_metadata,
        gather_indx=gather_indx,
        fused_activation=fused_activation,
        x_scale=x_scale,
        y_scale=y_scale,
        out_shape=(batch_size * c.experts_per_token, n // fused_activation.specs.reduction_n),
        out_dtype=torch.float8_e4m3fn,
    )


def make_precision_config(prepared: PreparedCase) -> PrecisionConfig:
    return PrecisionConfig(
        flexpoint_saturate_inf=True,
        b_mx_scale=prepared.w_scale,
        b_microblock_size=MXFP_BLOCK_SIZE.value,
        out_dtype=prepared.out_dtype,
        flex_ctx=FlexCtx(
            lhs_data=InFlexData(dtype=prepared.out_dtype, scale=prepared.x_scale),
            rhs_data=InFlexData(),
            out_data=OutFlexData(dtype=prepared.out_dtype, expected_scale=prepared.y_scale),
        ),
    )


def make_output_buffer(prepared: PreparedCase) -> torch.Tensor:
    return torch.zeros(prepared.out_shape, dtype=prepared.out_dtype, device=prepared.x.device)


def run_kernel(prepared: PreparedCase, kernel, precision_config: PrecisionConfig, out: torch.Tensor) -> torch.Tensor:
    return kernel(
        a=prepared.x,
        b=prepared.w,
        bias=prepared.bias,
        a_ragged_metadata=prepared.ragged_metadata,
        gather_indx=prepared.gather_indx,
        precision_config=precision_config,
        c=out,
        fused_activation=prepared.fused_activation,
    )


def run_provider(prepared: PreparedCase, provider: str) -> tuple[torch.Tensor, PrecisionConfig]:
    precision_config = make_precision_config(prepared)
    kernel = matmul if provider == "example" else reference_matmul
    y = run_kernel(prepared, kernel, precision_config, make_output_buffer(prepared))
    return y, precision_config


def _storage_nbytes(x: torch.Tensor | Tensor) -> int:
    if isinstance(x, Tensor):
        data = x.storage.data
        return int(data.numel() * data.element_size())
    return int(x.numel() * x.element_size())


def estimate_benchmark_work(c: MLPConfig, prepared: PreparedCase) -> tuple[int, int]:
    slice_sizes = prepared.ragged_metadata.slice_sizes
    n_tokens = int(slice_sizes.sum().item())
    active_slices = int((slice_sizes > 0).sum().item())
    n_slices = prepared.ragged_metadata.n_slices
    k, n = c.hidden_size, c.intermediate_size
    out_n = n // prepared.fused_activation.specs.reduction_n
    active_slice_bytes = active_slices * sum(
        _storage_nbytes(t) // n_slices for t in (prepared.w, prepared.w_scale, prepared.bias))

    flops = 2 * n_tokens * k * n
    nbytes = (n_tokens * k * prepared.x.element_size() + active_slice_bytes + n_tokens * out_n * torch.empty(
        (), dtype=prepared.out_dtype).element_size())
    return flops, nbytes


def benchmark_kernel(prepared: PreparedCase, kernel, flops: int, nbytes: int) -> tuple[float, float]:
    precision_config = make_precision_config(prepared)
    out = make_output_buffer(prepared)
    ms = do_bench_cudagraph(lambda: run_kernel(prepared, kernel, precision_config, out))
    seconds = ms * 1e-3
    return flops * 1e-12 / seconds, nbytes * 1e-12 / seconds


# ===-----------------------------------------------------------------------===#
# Unit Tests
# ===-----------------------------------------------------------------------===#

GPT_OSS_120B_CONFIG = MLPConfig(
    name="gpt-oss-120b",
    num_experts=128,
    experts_per_token=4,
    num_expert_shards=8,
    hidden_size=2880,
    intermediate_size=2 * 2880,
)


def is_blackwell():
    return (triton.runtime.driver.active.get_current_target().backend == "cuda"
            and torch.cuda.get_device_capability()[0] == 10)


@pytest.mark.parametrize("c", [GPT_OSS_120B_CONFIG])
@pytest.mark.parametrize("batch_size", get_batch_sizes(GPT_OSS_120B_CONFIG))
@pytest.mark.skipif(not is_blackwell(), reason="Gluon MoE BMM1 fused-gather is only supported on Blackwell GPUs")
def test_op(c: MLPConfig, batch_size: tuple[int, ...]):
    prepared = prepare_case(c, batch_size, device=f"cuda:{torch.cuda.current_device()}")
    ref_y, ref_precision = run_provider(prepared, "reference")
    cand_y, cand_precision = run_provider(prepared, "example")
    description = f"{c.name}-mm1-bs{prepared.batch_size}"
    assert_close(
        ref_y.to(torch.float32),
        cand_y.to(torch.float32),
        maxtol=0.125,
        rmstol=None,
        description=f"{description}:out",
        verbose=False,
    )
    ref_scale = ref_precision.flex_ctx.out_data.actual_scale
    cand_scale = cand_precision.flex_ctx.out_data.actual_scale
    if ref_scale is not None or cand_scale is not None:
        assert ref_scale is not None and cand_scale is not None
        assert_close(
            ref_scale.to(torch.float32),
            cand_scale.to(torch.float32),
            maxtol=1e-10,
            rmstol=1e-10,
            description=f"{description}:out_scale",
            verbose=False,
        )


# ===-----------------------------------------------------------------------===#
# Benchmarking
# ===-----------------------------------------------------------------------===#

BENCH_TITLE = ("GPT-OSS-120B MoE MM1 "
               f"E={GPT_OSS_120B_CONFIG.num_experts} "
               f"EP={GPT_OSS_120B_CONFIG.experts_per_token} "
               f"ES={GPT_OSS_120B_CONFIG.num_expert_shards} "
               f"B={GPT_OSS_120B_CONFIG.hidden_size}x{GPT_OSS_120B_CONFIG.intermediate_size}")
PEAK_TFLOPS = 5_000.0
PEAK_TBPS = 8.0


def _format_perf(result: tuple[float, float]) -> str:
    tflops, tbps = result
    return f"{tflops:8.2f} TFLOPS ({tflops / PEAK_TFLOPS:6.1%})  {tbps:6.2f} TBPS ({tbps / PEAK_TBPS:6.1%})"


def bench(c: MLPConfig = GPT_OSS_120B_CONFIG, uniform_routing: bool = False):
    batch_sizes = get_batch_sizes(c)
    batch_width = max(len("batch_size"), *(len(str(bs)) for bs in batch_sizes))
    perf_width = max(
        len("reference"),
        len(_format_perf((99999.99, 999.99))),
    )

    print(BENCH_TITLE, flush=True)
    print(f"Peak: {PEAK_TFLOPS / 1000:g} PFLOPS, {PEAK_TBPS:g} TBPS", flush=True)
    print(
        f"{'batch_size':>{batch_width}}  {'example':>{perf_width}}  {'reference':>{perf_width}}",
        flush=True,
    )
    print("-" * (batch_width + 2 + perf_width + 2 + perf_width), flush=True)

    device = f"cuda:{torch.cuda.current_device()}"
    for batch_size in batch_sizes:
        print(f"{batch_size:>{batch_width}}  ", end="", flush=True)

        prepared = prepare_case(
            c,
            batch_size,
            device=device,
            uniform_routing=uniform_routing,
        )
        flops, nbytes = estimate_benchmark_work(c, prepared)
        example = benchmark_kernel(prepared, matmul, flops, nbytes)
        print(f"{_format_perf(example):>{perf_width}}  ", end="", flush=True)

        prepared = prepare_case(
            c,
            batch_size,
            device=device,
            uniform_routing=uniform_routing,
            reference=True,
        )
        flops, nbytes = estimate_benchmark_work(c, prepared)
        reference = benchmark_kernel(prepared, reference_matmul, flops, nbytes)
        print(f"{_format_perf(reference):>{perf_width}}", flush=True)


if __name__ == "__main__":
    bench(uniform_routing=False)