2CTA Block-Scaled Matrix Multiplication
High-performance 2CTA warp-specialized block-scaled MMA. Two CTAs cooperate per output tile, sharing operands to increase arithmetic intensity and reduce the per-CTA SMEM footprint.
import argparse
import itertools
import pytest
import torch
import triton
import triton.experimental.gluon as gluon
import triton.experimental.gluon.language as gl
from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor
from triton._C.libtriton import nvidia
from triton.experimental.gluon.nvidia.blackwell import TensorDescriptor
from triton.experimental.gluon.language.nvidia.blackwell import (
TensorMemoryLayout,
TensorMemoryScalesLayout,
allocate_tensor_memory,
tensor_memory_descriptor,
clc,
tcgen05_copy,
tcgen05_commit,
tcgen05_mma_scaled,
mbarrier,
tma,
)
# ---------------------------------------------------------------------------
# Tile scheduler
# ---------------------------------------------------------------------------
@gluon.jit
def _planar_snake(lin_idx, m_tiles, n_tiles, minor_dim: gl.constexpr, tile_width: gl.constexpr):
major_size = n_tiles if minor_dim == 0 else m_tiles
minor_size = m_tiles if minor_dim == 0 else n_tiles
full_minor_tiles = minor_size // tile_width
full_minor_size = full_minor_tiles * tile_width
full_elements = full_minor_tiles * tile_width * major_size
minor_tile_idx = lin_idx // (tile_width * major_size)
full_minor_within = lin_idx % tile_width
full_major_within = (lin_idx // tile_width) % major_size
full_minor = minor_tile_idx * tile_width + full_minor_within
full_major = gl.where((minor_tile_idx % 2) == 0, full_major_within, major_size - 1 - full_major_within)
partial_width = minor_size - full_minor_size
partial_width = gl.where(partial_width > 0, partial_width, 1)
partial_lin = lin_idx - full_elements
partial_minor_within = partial_lin % partial_width
partial_major_within = (partial_lin // partial_width) % major_size
partial_minor = minor_tile_idx * tile_width + partial_minor_within
partial_major = gl.where((minor_tile_idx % 2) == 0, partial_major_within, major_size - 1 - partial_major_within)
in_full_tile = lin_idx < full_elements
minor = gl.where(in_full_tile, full_minor, partial_minor)
major = gl.where(in_full_tile, full_major, partial_major)
if minor_dim == 0:
return minor, major
return major, minor
def is_blackwell():
if not torch.cuda.is_available():
return False
target = triton.runtime.driver.active.get_current_target()
return target.backend == "cuda" and torch.cuda.get_device_capability()[0] == 10
@gluon.constexpr_function
def get_split_dim(cga_layout, dim):
return 1 << sum(b[dim] != 0 for b in cga_layout)
# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------
def random_quantized_tensor(MN, K, format):
assert format in ["mxfp4", "mxfp8", "nvfp4"]
VEC_SIZE = 16 if format == "nvfp4" else 32
# Generate a random quantized tensor and its scale factors, assuming we are
# scaling along the K dimension.
base = MXFP4Tensor(size=(MN, K), device="cuda").random()
scale = MXScaleTensor(size=(MN, K // VEC_SIZE), device="cuda").random(low=1 / 128, high=2.0)
# Compute the dequantized tensor to use for testing.
ref = base.to(torch.float32)
scale_ref = scale.to(torch.float32)
value = ref * scale_ref.repeat_interleave(VEC_SIZE, dim=1)
if format == "mxfp8":
# For mxfp8, convert the tensor to a regular float8 torch tensor.
return ref.to(torch.float8_e4m3fn), scale.data, value
elif format == "mxfp4":
# For mxfp4, pack the elements along the K dimension.
return base.to_packed_tensor(dim=1), scale.data, value
else:
# For nvfp4, pack the elements along the K dimension, and convert the
# scale factors to float8_e4m3fn.
return base.to_packed_tensor(dim=1), scale_ref.to(torch.float8_e4m3fn), value
def align_to(a, b):
# Return next multiple of `b` greater than or equal to `a`.
return triton.cdiv(a, b) * b
def swizzle_scales_packed_block(scales: torch.Tensor):
# When the scale tensor is not an even multiple of [128, 4], we need to pad
# the scale tensor so it can use the packed block format.
PAD_MN = align_to(scales.shape[0], 128) - scales.shape[0]
PAD_K = align_to(scales.shape[1], 4) - scales.shape[1]
scales = torch.nn.functional.pad(scales, (0, PAD_K, 0, PAD_MN))
MN, SCALE_K = scales.shape[0], scales.shape[1]
REP_MN = MN // 128
REP_K = SCALE_K // 4
scales = scales.reshape(REP_MN, 4, 32, REP_K, 4)
scales = scales.permute(0, 3, 2, 1, 4)
return scales.contiguous()
# ---------------------------------------------------------------------------
# Autotuning configs and hook
# ---------------------------------------------------------------------------
def mma_scaled_get_configs(pre_hook=None, cga_layouts=None):
if cga_layouts is None:
cga_layouts = [(), ((1, 0), )]
return [
triton.Config(
{
"BLOCK_M": BM,
"BLOCK_N": BN,
"BLOCK_K": BK,
"EPILOGUE_BLOCK_N": epilogue_n,
"num_buffers": stages,
"num_acc_buffers": acc_buffers,
"GRID_MINOR_DIM": minor_dim,
"GRID_TILE_WIDTH": grid_tile_width,
"CGA_LAYOUT": cga_layout,
},
num_warps=4,
num_ctas=2**len(cga_layout),
pre_hook=pre_hook,
)
for BM in (128, 256)
for BN in (128, 256)
for BK in (128, 256)
for epilogue_n in (64, BN)
for minor_dim in (0, 1)
for grid_tile_width in (4, 8, 16)
for stages in (3, 4, 5)
for acc_buffers in (1, 2)
for cga_layout in cga_layouts
# tcgen05_mma_scaled requires BLOCK_M_PER_CTA == 128
if BM // (2**len(cga_layout)) == 128 if epilogue_n <= BN
]
def mma_scaled_tma_set_block_size_hook(nargs):
block_m = nargs["BLOCK_M"]
block_n = nargs["BLOCK_N"]
block_k = nargs["BLOCK_K"]
epilogue_n = nargs["EPILOGUE_BLOCK_N"]
cga_layout = nargs["CGA_LAYOUT"]
a_base = nargs["a_desc"].base
b_base = nargs["b_desc"].base
a_is_fp4 = a_base.dtype == torch.uint8
b_is_fp4 = b_base.dtype == torch.uint8
mixed_prec = a_is_fp4 != b_is_fp4
a_elem_per_byte = 2 if a_is_fp4 else 1
b_elem_per_byte = 2 if b_is_fp4 else 1
a_block = [block_m, block_k // a_elem_per_byte]
b_block = [block_n, block_k // b_elem_per_byte]
c_block = [block_m, epilogue_n]
nargs["a_desc"].block_shape = a_block
nargs["b_desc"].block_shape = b_block
nargs["c_desc"].block_shape = c_block
cga = tuple(tuple(x) for x in cga_layout) if cga_layout else None
nargs["a_desc"].layout = gl.NVMMASharedLayout.get_default_for(a_block, gl.uint8 if a_is_fp4 else gl.float8e4nv,
fp4_padded=(mixed_prec and a_is_fp4), cga_layout=cga)
nargs["b_desc"].layout = gl.NVMMASharedLayout.get_default_for(b_block, gl.uint8 if b_is_fp4 else gl.float8e4nv,
fp4_padded=(mixed_prec and b_is_fp4), cga_layout=cga)
c_dtype = getattr(gl, str(nargs["c_desc"].base.dtype).split('.')[1])
nargs["c_desc"].layout = gl.NVMMASharedLayout.get_default_for(c_block, c_dtype, cga_layout=cga)
a_scale_base = nargs["a_scale_desc"].base
is_nvfp4 = a_scale_base.dtype == torch.float8_e4m3fn
vec_size = 16 if is_nvfp4 else 32
rep_m = block_m // 128
rep_n = block_n // 128
rep_k = block_k // (vec_size * 4)
nargs["a_scale_desc"].block_shape = [1, rep_m, rep_k, 2, 256]
nargs["b_scale_desc"].block_shape = [1, rep_n, rep_k, 2, 256]
if cga_layout:
cga_a_scale = [[0, 1, 0, 0, 0]]
cga_b_scale = [[0, 0, 0, 0, 0]]
nargs["a_scale_desc"].layout = gl.NVMMASharedLayout(swizzle_byte_width=0, element_bitwidth=8, rank=5,
cga_layout=cga_a_scale)
nargs["b_scale_desc"].layout = gl.NVMMASharedLayout(swizzle_byte_width=0, element_bitwidth=8, rank=5,
cga_layout=cga_b_scale)
else:
no_swizzle = gl.NVMMASharedLayout(swizzle_byte_width=0, element_bitwidth=8, rank=5)
nargs["a_scale_desc"].layout = no_swizzle
nargs["b_scale_desc"].layout = no_swizzle
@gluon.jit
def unswizzle_scales_shared_memory(smem, BLOCK_MN: gl.constexpr, BLOCK_K: gl.constexpr, VEC_SIZE: gl.constexpr):
smem = smem.reshape((smem.shape[1], smem.shape[2], 32, 4, 4))
smem = smem.permute((0, 3, 2, 1, 4))
return smem.reshape((BLOCK_MN, BLOCK_K // VEC_SIZE))
@gluon.jit
def async_mma_scaled_impl(a_smem, b_smem, a_scale_smem, b_scale_smem, acc_tmem, use_acc, pred):
A_ELEM_PER_BYTE: gl.constexpr = 2 if a_smem.dtype == gl.uint8 else 1
BLOCK_M: gl.constexpr = a_smem.shape[0]
BLOCK_N: gl.constexpr = b_smem.shape[0]
BLOCK_K: gl.constexpr = a_smem.shape[1] * A_ELEM_PER_BYTE
# Recall we use `uint8` to represent fp4 elements.
VEC_SIZE: gl.constexpr = 32 if a_scale_smem.dtype == gl.uint8 else 16
a_scale = unswizzle_scales_shared_memory(a_scale_smem, BLOCK_M, BLOCK_K, VEC_SIZE)
b_scale = unswizzle_scales_shared_memory(b_scale_smem, BLOCK_N, BLOCK_K, VEC_SIZE)
# We don't need to hoist the scales tensor memory allocations outside of the loop,
# so we can pull them into this helper function.
two_ctas: gl.constexpr = acc_tmem.type.layout.two_ctas
a_scale_layout: gl.constexpr = TensorMemoryScalesLayout(cga_layout=[[1, 0]] if two_ctas else [])
b_scale_layout: gl.constexpr = TensorMemoryScalesLayout(cga_layout=[[0, 0]] if two_ctas else [])
a_scale_tmem = allocate_tensor_memory(a_scale.dtype, a_scale.type.shape, a_scale_layout)
b_scale_tmem = allocate_tensor_memory(b_scale.dtype, b_scale.type.shape, b_scale_layout)
tcgen05_copy(a_scale, a_scale_tmem)
tcgen05_copy(b_scale, b_scale_tmem)
a_format: gl.constexpr = "e2m1" if a_smem.dtype == gl.uint8 else "e4m3"
b_format: gl.constexpr = "e2m1" if b_smem.dtype == gl.uint8 else "e4m3"
tcgen05_mma_scaled(a_smem, b_smem.permute((1, 0)), acc_tmem, a_scale_tmem, b_scale_tmem, a_format, b_format,
use_acc=use_acc, pred=pred)
# This helper function computes all the load indexing and issues the async loads
# based on the current `pid_m`, `pid_n`, and `k` indices. The compiler will run
# loop-invariant code motion to hoist code that does not depend on `k`, like
# `pid_m * BLOCK_M`, outside of the inner loop, so we can safely abstract the
# load indexing without performance loss.
#
# Encapsulating the load indexing logic will help keep our pipelined kernel code
# clean, as pipelining can get messy.
@gluon.jit
def issue_loads(producer, pid_m, pid_n, k, a_desc, b_desc, a_scale_desc, b_scale_desc, a_bufs, b_bufs, a_scale_bufs,
b_scale_bufs, bars, pred, multicast_b_scale: gl.constexpr = False):
A_ELEM_PER_BYTE: gl.constexpr = 2 if a_desc.dtype == gl.uint8 else 1
B_ELEM_PER_BYTE: gl.constexpr = 2 if b_desc.dtype == gl.uint8 else 1
BLOCK_M: gl.constexpr = a_desc.block_shape[0]
BLOCK_N: gl.constexpr = b_desc.block_shape[0]
BLOCK_K: gl.constexpr = a_desc.block_shape[1] * A_ELEM_PER_BYTE
REP_M: gl.constexpr = a_scale_desc.block_shape[1]
REP_N: gl.constexpr = b_scale_desc.block_shape[1]
A_REP_K: gl.constexpr = a_scale_desc.block_shape[2]
B_REP_K: gl.constexpr = b_scale_desc.block_shape[2]
off_m = pid_m * BLOCK_M
off_n = pid_n * BLOCK_N
off_m_a_scale = pid_m * REP_M
off_n_b_scale = pid_n * REP_N
off_k_a = k // A_ELEM_PER_BYTE
off_k_b = k // B_ELEM_PER_BYTE
off_k_a_scale = (k // BLOCK_K) * A_REP_K
off_k_b_scale = (k // BLOCK_K) * B_REP_K
index = producer.index
bar = bars.index(index)
mbarrier.expect(
bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta + a_scale_desc.nbytes_per_cta + b_scale_desc.nbytes_per_cta,
pred)
tma.async_load(a_desc, [off_m, off_k_a], bar, a_bufs.index(index), pred)
tma.async_load(b_desc, [off_n, off_k_b], bar, b_bufs.index(index), pred)
tma.async_load(a_scale_desc, [0, off_m_a_scale, off_k_a_scale, 0, 0], bar, a_scale_bufs.index(index), pred)
tma.async_load(b_scale_desc, [0, off_n_b_scale, off_k_b_scale, 0, 0], bar, b_scale_bufs.index(index), pred,
multicast=multicast_b_scale)
return producer.next(pred)
@gluon.jit
def issue_mma(consumer, c_bars, a_bufs, b_bufs, a_scale_bufs, b_scale_bufs, producer, p_bars, acc_tmem, use_acc, pred):
c_index = consumer.index
mbarrier.wait(c_bars.index(c_index), consumer.phase, pred)
async_mma_scaled_impl(a_bufs.index(c_index), b_bufs.index(c_index), a_scale_bufs.index(c_index),
b_scale_bufs.index(c_index), acc_tmem, use_acc, pred)
tcgen05_commit(p_bars.index(producer.index), pred)
return consumer.next(pred), producer.next(pred)
@gluon.aggregate
class Counter:
index: gl.tensor
phase: gl.tensor
num_barriers: gl.constexpr
@gluon.jit
def create(phase, num_barriers: gl.constexpr):
return Counter(gl.to_tensor(0), gl.to_tensor(phase), num_barriers)
@gluon.must_use_result
@gluon.jit
def next(self, pred=True):
incr = self.index + gl.where(pred, 1, 0)
rollover = incr == self.num_barriers
index = gl.where(rollover, 0, incr)
phase = gl.where(rollover, self.phase ^ 1, self.phase)
return Counter(index, phase, self.num_barriers)
@gluon.aggregate
class ClcTileSchedulerConsumer:
has_work: gl.tensor
tile_id: gl.tensor
pid_m: gl.tensor
pid_n: gl.tensor
num_pid_m: gl.tensor
num_pid_n: gl.tensor
TILE_M: gl.constexpr
TILE_N: gl.constexpr
MINOR_DIM: gl.constexpr
GRID_TILE_WIDTH: gl.constexpr
clc_result_buffers: gl.shared_memory_descriptor
clc_barriers: gl.shared_memory_descriptor
clc_planar_pid_buffers: gl.shared_memory_descriptor
clc_planar_ready_bars: gl.shared_memory_descriptor
clc_consumed_bars: gl.shared_memory_descriptor
counter: Counter
consumed_counter: Counter
@gluon.jit
def initialize(M, N, TILE_M: gl.constexpr, TILE_N: gl.constexpr, MINOR_DIM: gl.constexpr,
GRID_TILE_WIDTH: gl.constexpr, clc_result_buffers, clc_barriers, clc_planar_pid_buffers,
clc_planar_ready_bars, clc_consumed_bars):
tile_id = gl.program_id(axis=0)
num_pid_m = gl.cdiv(M, TILE_M)
num_pid_n = gl.cdiv(N, TILE_N)
pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, MINOR_DIM, GRID_TILE_WIDTH)
has_work = gl.to_tensor(True)
counter = Counter.create(0, clc_barriers.shape[0])
consumed_counter = Counter.create(0, clc_barriers.shape[0])
return ClcTileSchedulerConsumer(
has_work,
tile_id,
pid_m,
pid_n,
num_pid_m,
num_pid_n,
TILE_M,
TILE_N,
MINOR_DIM,
GRID_TILE_WIDTH,
clc_result_buffers,
clc_barriers,
clc_planar_pid_buffers,
clc_planar_ready_bars,
clc_consumed_bars,
counter,
consumed_counter,
)
@gluon.jit
def get_offsets(self):
return self.pid_m * self.TILE_M, self.pid_n * self.TILE_N
@gluon.jit
def step(self, iteration):
# The 0-th iteration uses the program_id as the tile_id.
# At the end of each iteration we prefetch the next tile.
# As such we must signal the consumed slot at the end of
# each iteration skipping the first one.
consumed_counter = self.consumed_counter
if iteration > 0:
mbarrier.arrive(self.clc_consumed_bars.index(consumed_counter.index))
consumed_counter = consumed_counter.next()
counter = self.counter
barrier = self.clc_barriers.index(counter.index)
result = self.clc_result_buffers.index(counter.index)
mbarrier.wait(barrier, counter.phase)
clc_res = clc.load_result(result)
mbarrier.wait(self.clc_planar_ready_bars.index(counter.index), counter.phase)
planar_slot = self.clc_planar_pid_buffers.index(counter.index)
planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
[[0]] * (gl.num_ctas().bit_length() - 1))
packed_pid = planar_slot.load(planar_layout).reshape([])
pid_m = ((packed_pid >> 32) & 0xFFFFFFFF).to(gl.int32)
pid_n = (packed_pid & 0xFFFFFFFF).to(gl.int32)
has_work = clc_res.is_canceled()
tile_id = self.tile_id
if has_work:
tile_id = clc_res.program_id(0)
return ClcTileSchedulerConsumer(
has_work,
tile_id,
pid_m,
pid_n,
self.num_pid_m,
self.num_pid_n,
self.TILE_M,
self.TILE_N,
self.MINOR_DIM,
self.GRID_TILE_WIDTH,
self.clc_result_buffers,
self.clc_barriers,
self.clc_planar_pid_buffers,
self.clc_planar_ready_bars,
self.clc_consumed_bars,
counter.next(),
consumed_counter,
)
# ---------------------------------------------------------------------------
# Partitions
# ---------------------------------------------------------------------------
@gluon.aggregate
class PartitionArgs:
a_desc: tma.tensor_descriptor
b_desc: tma.tensor_descriptor
c_desc: tma.tensor_descriptor
a_scale_desc: tma.tensor_descriptor
b_scale_desc: tma.tensor_descriptor
a_bufs: gl.shared_memory_descriptor
b_bufs: gl.shared_memory_descriptor
a_scale_bufs: gl.shared_memory_descriptor
b_scale_bufs: gl.shared_memory_descriptor
load_empty_bars: gl.shared_memory_descriptor
load_ready_bars: gl.shared_memory_descriptor
acc_bufs: tensor_memory_descriptor
acc_empty_bars: gl.shared_memory_descriptor
acc_ready_bars: gl.shared_memory_descriptor
clc_result_buffers: gl.shared_memory_descriptor
clc_barriers: gl.shared_memory_descriptor
clc_planar_pid_buffers: gl.shared_memory_descriptor
clc_planar_ready_bars: gl.shared_memory_descriptor
clc_consumed_bars: gl.shared_memory_descriptor
MINOR_DIM: gl.constexpr
GRID_TILE_WIDTH: gl.constexpr
@gluon.jit
def get_clc_consumer(self):
return ClcTileSchedulerConsumer.initialize(
self.c_desc.shape[0],
self.c_desc.shape[1],
self.a_desc.block_shape[0],
self.b_desc.block_shape[0],
self.MINOR_DIM,
self.GRID_TILE_WIDTH,
self.clc_result_buffers,
self.clc_barriers,
self.clc_planar_pid_buffers,
self.clc_planar_ready_bars,
self.clc_consumed_bars,
)
@gluon.jit
def mma_scaled_load_partition(p):
A_ELEM_PER_BYTE: gl.constexpr = 2 if p.a_desc.dtype == gl.uint8 else 1
BLOCK_K: gl.constexpr = p.a_desc.block_shape[1] * A_ELEM_PER_BYTE
K = p.a_desc.shape[1] * A_ELEM_PER_BYTE
state = Counter.create(1, p.load_empty_bars.shape[0])
scheduler = p.get_clc_consumer()
i = 0
while scheduler.has_work:
for k in range(0, K, BLOCK_K):
mbarrier.wait(p.load_empty_bars.index(state.index), state.phase)
state = issue_loads(state, scheduler.pid_m, scheduler.pid_n, k, p.a_desc, p.b_desc, p.a_scale_desc,
p.b_scale_desc, p.a_bufs, p.b_bufs, p.a_scale_bufs, p.b_scale_bufs, p.load_ready_bars,
pred=True, multicast_b_scale=gl.num_ctas() > 1)
scheduler = scheduler.step(i)
i += 1
@gluon.jit
def mma_scaled_mma_partition(p):
A_ELEM_PER_BYTE: gl.constexpr = 2 if p.a_desc.dtype == gl.uint8 else 1
BLOCK_K: gl.constexpr = p.a_desc.block_shape[1] * A_ELEM_PER_BYTE
K = p.a_desc.shape[1] * A_ELEM_PER_BYTE
load_state = Counter.create(0, p.load_empty_bars.shape[0])
acc_state = Counter.create(1, p.acc_empty_bars.shape[0])
scheduler = p.get_clc_consumer()
i = 0
while scheduler.has_work:
mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase)
acc_buf = p.acc_bufs.index(acc_state.index)
use_acc = False
for k in range(0, K, BLOCK_K):
_, load_state = issue_mma(load_state, p.load_ready_bars, p.a_bufs, p.b_bufs, p.a_scale_bufs, p.b_scale_bufs,
load_state, p.load_empty_bars, acc_buf, use_acc, pred=True)
use_acc = True
tcgen05_commit(p.acc_ready_bars.index(acc_state.index))
acc_state = acc_state.next()
scheduler = scheduler.step(i)
i += 1
@gluon.jit
def mma_scaled_epilogue_partition(p):
tile_m: gl.constexpr = p.c_desc.block_shape[0]
BLOCK_N: gl.constexpr = p.b_desc.block_shape[0]
EPILOGUE_BLOCK_N: gl.constexpr = p.c_desc.block_shape[1]
subtile_factor: gl.constexpr = BLOCK_N // EPILOGUE_BLOCK_N
subtile_stages: gl.constexpr = 1 if subtile_factor == 1 else 2
acc_state = Counter.create(0, p.acc_empty_bars.shape[0])
acc_smems = gl.allocate_shared_memory(p.c_desc.dtype, [subtile_stages, tile_m, EPILOGUE_BLOCK_N], p.c_desc.layout)
sub_acc_state = Counter.create(0, subtile_stages)
scheduler = p.get_clc_consumer()
i = 0
while scheduler.has_work:
off_m, off_n = scheduler.get_offsets()
mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
acc_buf = p.acc_bufs.index(acc_state.index)
for s in gl.static_range(subtile_factor):
acc_sub = acc_buf.slice(EPILOGUE_BLOCK_N * s, EPILOGUE_BLOCK_N)
acc_smem = acc_smems.index(sub_acc_state.index)
acc = acc_sub.load().to(p.c_desc.dtype)
tma.store_wait(pendings=subtile_stages - 1)
acc_smem.store(acc)
tma.async_copy_shared_to_global(p.c_desc, [off_m, off_n + EPILOGUE_BLOCK_N * s], acc_smem)
sub_acc_state = sub_acc_state.next()
mbarrier.arrive(p.acc_empty_bars.index(acc_state.index), count=1)
acc_state = acc_state.next()
scheduler = scheduler.step(i)
i += 1
tma.store_wait(0)
@gluon.jit
def mma_scaled_clc_partition(p):
TILE_M: gl.constexpr = p.a_desc.block_shape[0]
TILE_N: gl.constexpr = p.b_desc.block_shape[0]
has_work = gl.to_tensor(True)
num_pid_m = gl.cdiv(p.c_desc.shape[0], TILE_M)
num_pid_n = gl.cdiv(p.c_desc.shape[1], TILE_N)
state = Counter.create(0, p.clc_barriers.shape[0])
consumed_state = Counter.create(1, p.clc_barriers.shape[0])
ACC_STAGES: gl.constexpr = p.clc_barriers.shape[0]
i = 0
while has_work:
# Reuse the slot only after all consumer partitions signaled consumed.
mbarrier.wait(p.clc_consumed_bars.index(consumed_state.index), consumed_state.phase, pred=(i >= ACC_STAGES))
barrier = p.clc_barriers.index(state.index)
result = p.clc_result_buffers.index(state.index)
# 16: clc.try_cancel has a `.b128` modifier
mbarrier.expect(barrier, 16)
clc.try_cancel(result, barrier)
mbarrier.wait(barrier, state.phase)
clc_res = clc.load_result(result)
has_work = clc_res.is_canceled()
pid_m = gl.to_tensor(0)
pid_n = gl.to_tensor(0)
if has_work:
tile_id = clc_res.program_id(0)
pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, p.MINOR_DIM, p.GRID_TILE_WIDTH)
packed_pid = (pid_m.to(gl.int64) << 32) | (pid_n.to(gl.int64) & 0xFFFFFFFF)
planar_slot = p.clc_planar_pid_buffers.index(state.index)
planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
[[0]] * (gl.num_ctas().bit_length() - 1))
planar_slot.store(gl.full([1], packed_pid, gl.int64, layout=planar_layout))
mbarrier.arrive(p.clc_planar_ready_bars.index(state.index))
state = state.next()
consumed_state = consumed_state.next()
i += 1
# ---------------------------------------------------------------------------
# Kernel
# ---------------------------------------------------------------------------
@gluon.jit
def mma_scaled_warp_specialized_kernel(a_desc, b_desc, c_desc, a_scale_desc, b_scale_desc, M, N, K, A_ELEM_PER_BYTE,
num_buffers: gl.constexpr, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr,
BLOCK_K: gl.constexpr, EPILOGUE_BLOCK_N: gl.constexpr,
num_acc_buffers: gl.constexpr, GRID_MINOR_DIM: gl.constexpr,
GRID_TILE_WIDTH: gl.constexpr, CGA_LAYOUT: gl.constexpr):
NUM_CTAS: gl.constexpr = gl.num_ctas()
TWO_CTAS: gl.constexpr = NUM_CTAS > 1
BLOCK_M_PER_CTA: gl.constexpr = BLOCK_M // NUM_CTAS
gl.static_assert(BLOCK_M_PER_CTA == 64 or BLOCK_M_PER_CTA == 128)
N_PARTITIONS: gl.constexpr = 4
a_bufs = gl.allocate_shared_memory(a_desc.dtype, [num_buffers] + a_desc.block_shape, a_desc.layout)
b_bufs = gl.allocate_shared_memory(b_desc.dtype, [num_buffers] + b_desc.block_shape, b_desc.layout)
a_scale_bufs = gl.allocate_shared_memory(a_scale_desc.dtype, [num_buffers] + a_scale_desc.block_shape,
a_scale_desc.layout)
b_scale_bufs = gl.allocate_shared_memory(b_scale_desc.dtype, [num_buffers] + b_scale_desc.block_shape,
b_scale_desc.layout)
tmem_layout: gl.constexpr = TensorMemoryLayout([BLOCK_M_PER_CTA, BLOCK_N], col_stride=1, cga_layout=CGA_LAYOUT,
two_ctas=TWO_CTAS)
load_empty_bars = mbarrier.allocate_mbarrier(batch=num_buffers)
load_ready_bars = mbarrier.allocate_mbarrier(batch=num_buffers, two_ctas=TWO_CTAS)
for i in gl.static_range(num_buffers):
mbarrier.init(load_empty_bars.index(i), count=1)
mbarrier.init(load_ready_bars.index(i), count=1)
acc_empty_bars = mbarrier.allocate_mbarrier(batch=num_acc_buffers, two_ctas=TWO_CTAS)
acc_ready_bars = mbarrier.allocate_mbarrier(batch=num_acc_buffers)
for i in gl.static_range(num_acc_buffers):
mbarrier.init(acc_empty_bars.index(i), count=1)
mbarrier.init(acc_ready_bars.index(i), count=1)
clc_barriers = mbarrier.allocate_mbarrier(batch=num_acc_buffers)
clc_planar_ready_bars = mbarrier.allocate_mbarrier(batch=num_acc_buffers)
clc_consumed_bars = mbarrier.allocate_mbarrier(batch=num_acc_buffers, two_ctas=TWO_CTAS)
for i in gl.static_range(num_acc_buffers):
mbarrier.init(clc_barriers.index(i), count=1)
mbarrier.init(clc_planar_ready_bars.index(i), count=1)
mbarrier.init(clc_consumed_bars.index(i), count=N_PARTITIONS - 1)
cga_layout_clc: gl.constexpr = [[0]] * (gl.num_ctas().bit_length() - 1)
clc_layout: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, [0], cga_layout=cga_layout_clc)
clc_result_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 2], clc_layout)
clc_planar_pid_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 1], clc_layout)
acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout)
p = PartitionArgs(a_desc, b_desc, c_desc, a_scale_desc, b_scale_desc, a_bufs, b_bufs, a_scale_bufs, b_scale_bufs,
load_empty_bars, load_ready_bars, acc_bufs, acc_empty_bars, acc_ready_bars, clc_result_buffers,
clc_barriers, clc_planar_pid_buffers, clc_planar_ready_bars, clc_consumed_bars, GRID_MINOR_DIM,
GRID_TILE_WIDTH)
gl.warp_specialize([
(mma_scaled_epilogue_partition, (p, )),
(mma_scaled_mma_partition, (p, )),
(mma_scaled_load_partition, (p, )),
(mma_scaled_clc_partition, (p, )),
], [1, 1, 1], [24, 24, 24])
mma_scaled_kernel = triton.autotune(
configs=mma_scaled_get_configs(pre_hook=mma_scaled_tma_set_block_size_hook),
key=["M", "N", "K", "A_ELEM_PER_BYTE"],
)(mma_scaled_warp_specialized_kernel)
mma_scaled_1cta_kernel = triton.autotune(
configs=mma_scaled_get_configs(pre_hook=mma_scaled_tma_set_block_size_hook, cga_layouts=[()]),
key=["M", "N", "K", "A_ELEM_PER_BYTE"],
)(mma_scaled_warp_specialized_kernel)
mma_scaled_2cta_kernel = triton.autotune(
configs=mma_scaled_get_configs(pre_hook=mma_scaled_tma_set_block_size_hook, cga_layouts=[((1, 0), )]),
key=["M", "N", "K", "A_ELEM_PER_BYTE"],
)(mma_scaled_warp_specialized_kernel)
# ---------------------------------------------------------------------------
# Wrapper
# ---------------------------------------------------------------------------
def make_dummy_descriptors(A, B, A_scale, B_scale, out_dtype, M, N):
"""Create TMA descriptors with dummy block shapes; the hook sets the real ones."""
dummy_block_2d = [1, 1]
dummy_layout_2d = gl.NVMMASharedLayout.get_default_for(dummy_block_2d, gl.float8e4nv)
a_desc = TensorDescriptor.from_tensor(A, dummy_block_2d, dummy_layout_2d)
b_desc = TensorDescriptor.from_tensor(B, dummy_block_2d, dummy_layout_2d)
C = torch.empty(M, N, device="cuda", dtype=out_dtype)
C_dtype = getattr(gl, str(out_dtype).split('.')[1])
c_layout = gl.NVMMASharedLayout.get_default_for(dummy_block_2d, C_dtype)
c_desc = TensorDescriptor.from_tensor(C, dummy_block_2d, c_layout)
A_scale_5d = A_scale.reshape(1, A_scale.shape[0], A_scale.shape[1], 2, 256)
B_scale_5d = B_scale.reshape(1, B_scale.shape[0], B_scale.shape[1], 2, 256)
dummy_block_5d = [1, 1, 1, 2, 256]
dummy_layout_5d = gl.NVMMASharedLayout(swizzle_byte_width=0, element_bitwidth=8, rank=5)
a_scale_desc = TensorDescriptor.from_tensor(A_scale_5d, dummy_block_5d, dummy_layout_5d)
b_scale_desc = TensorDescriptor.from_tensor(B_scale_5d, dummy_block_5d, dummy_layout_5d)
return a_desc, b_desc, c_desc, a_scale_desc, b_scale_desc
def mma_scaled_warp_specialized(A, B, A_scale, B_scale, VEC_SIZE, GRID_MINOR_DIM=0, GRID_TILE_WIDTH=4,
out_dtype=torch.float16, BLOCK_M=128, BLOCK_N=256, BLOCK_K=None, EPILOGUE_BLOCK_N=None,
num_buffers=3, acc_buffers=None, num_ctas=1):
"""Warp-specialized block-scale MMA (supports 1CTA and 2CTA)."""
if BLOCK_K is None:
BLOCK_K = 128 if torch.float8_e4m3fn in [A.dtype, B.dtype] else 256
if EPILOGUE_BLOCK_N is None:
EPILOGUE_BLOCK_N = BLOCK_N
if acc_buffers is None:
acc_buffers = 2 if BLOCK_N < 256 else 1
M, N = A.shape[0], B.shape[0]
IS_FP4_A = A.dtype == torch.uint8
K = A.shape[1] * (2 if IS_FP4_A else 1)
cga_layout = ((1, 0), ) if num_ctas > 1 else ()
A_desc, B_desc, C_desc, A_scale_desc, B_scale_desc = make_dummy_descriptors(A, B, A_scale, B_scale, out_dtype, M, N)
mma_scaled_tma_set_block_size_hook({
"a_desc": A_desc,
"b_desc": B_desc,
"c_desc": C_desc,
"a_scale_desc": A_scale_desc,
"b_scale_desc": B_scale_desc,
"BLOCK_M": BLOCK_M,
"BLOCK_N": BLOCK_N,
"BLOCK_K": BLOCK_K,
"EPILOGUE_BLOCK_N": EPILOGUE_BLOCK_N,
"CGA_LAYOUT": cga_layout,
})
num_pid = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
grid = (num_pid, )
A_ELEM_PER_BYTE = 2 if IS_FP4_A else 1
mma_scaled_warp_specialized_kernel[grid](
A_desc,
B_desc,
C_desc,
A_scale_desc,
B_scale_desc,
M,
N,
K,
A_ELEM_PER_BYTE,
num_buffers,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EPILOGUE_BLOCK_N,
acc_buffers,
GRID_MINOR_DIM,
GRID_TILE_WIDTH,
cga_layout,
num_ctas=num_ctas,
)
return C_desc.base
def mma_scaled_matmul(A, B, A_scale, B_scale, VEC_SIZE, out_dtype=torch.float16, num_ctas=None):
"""Autotuned block-scaled matmul.
Args:
num_ctas: None = autotune across all configs (1CTA and 2CTA),
1 = autotune 1CTA configs only,
2 = autotune 2CTA configs only.
"""
M, N = A.shape[0], B.shape[0]
IS_FP4_A = A.dtype == torch.uint8
A_ELEM_PER_BYTE = 2 if IS_FP4_A else 1
K = A.shape[1] * A_ELEM_PER_BYTE
a_desc, b_desc, c_desc, a_scale_desc, b_scale_desc = make_dummy_descriptors(A, B, A_scale, B_scale, out_dtype, M, N)
def grid(meta):
num_tiles = triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])
return (num_tiles, )
kernel = {None: mma_scaled_kernel, 1: mma_scaled_1cta_kernel, 2: mma_scaled_2cta_kernel}[num_ctas]
kernel[grid](a_desc, b_desc, c_desc, a_scale_desc, b_scale_desc, M, N, K, A_ELEM_PER_BYTE)
return c_desc.base
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("K", [128, 640, 704, 1152, 4096])
@pytest.mark.parametrize("M, N", [(2048, 2048), (500, 600), (256, 256), (128, 128), (8192, 8192)])
@pytest.mark.parametrize("a_format, b_format",
list(itertools.product(["mxfp8", "mxfp4"], repeat=2)) + [("nvfp4", "nvfp4")])
@pytest.mark.parametrize("num_ctas, BLOCK_N, EPILOGUE_BLOCK_N, num_buffers", [
(2, 256, 256, 4),
(2, 256, 64, 5),
(2, 128, 64, 6),
(1, 256, 256, 3),
(1, 256, 64, 3),
(1, 128, 64, 5),
])
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_mma_scaled_warp_specialized(M, N, K, a_format, b_format, num_ctas, BLOCK_N, EPILOGUE_BLOCK_N, num_buffers):
if a_format != b_format and K % 128 != 0:
pytest.skip("fp4 packed tensor descriptor requires K to be a multiple of 128")
BLOCK_M = 256 if num_ctas > 1 else 128
torch.manual_seed(0)
A, A_scale, A_ref = random_quantized_tensor(M, K, a_format)
B, B_scale, B_ref = random_quantized_tensor(N, K, b_format)
VEC_SIZE = 16 if a_format == "nvfp4" else 32
A_scale = swizzle_scales_packed_block(A_scale)
B_scale = swizzle_scales_packed_block(B_scale)
C_ref = A_ref @ B_ref.T
C = mma_scaled_warp_specialized(A, B, A_scale, B_scale, VEC_SIZE, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
EPILOGUE_BLOCK_N=EPILOGUE_BLOCK_N, num_buffers=num_buffers, num_ctas=num_ctas)
torch.testing.assert_close(C_ref, C.to(torch.float32), atol=1e-3, rtol=1e-3)
# ---------------------------------------------------------------------------
# Benchmark
# ---------------------------------------------------------------------------
if is_blackwell():
cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
cublas = nvidia.cublas.CublasLt(cublas_workspace)
else:
cublas = None
CUBLAS_FORMATS = {"mxfp8", "nvfp4"}
def cublas_block_scaled_matmul(A, B, A_scale_flat, B_scale_flat, fmt):
"""cuBLAS block-scaled matmul. Supports mxfp8 and nvfp4 (mxfp4 not supported by cuBLAS)."""
M, N = A.shape[0], B.shape[0]
output = torch.empty((M, N), dtype=torch.float16, device="cuda")
if fmt == "mxfp8":
cublas.block_scaled_matmul_mxfp8(A, B, output, A_scale_flat, B_scale_flat)
elif fmt == "nvfp4":
cublas.block_scaled_matmul_nvfp4(A, B, output, A_scale_flat, B_scale_flat)
else:
raise ValueError(f"cuBLAS does not support format: {fmt}")
return output
ALL_FORMATS = [("mxfp8", "mxfp8"), ("nvfp4", "nvfp4"), ("mxfp8", "mxfp4"), ("mxfp4", "mxfp4")]
MNK_VALS = [8192, 16384, 32768]
BEST_1CTA_CONFIG = dict(BLOCK_M=128, BLOCK_N=256, EPILOGUE_BLOCK_N=64, num_buffers=3, num_ctas=1, GRID_MINOR_DIM=1,
GRID_TILE_WIDTH=8)
BEST_2CTA_CONFIG = dict(BLOCK_M=256, BLOCK_N=256, EPILOGUE_BLOCK_N=64, num_buffers=5, num_ctas=2, GRID_MINOR_DIM=0,
GRID_TILE_WIDTH=8)
def make_fn(variant, A, B, A_scale, B_scale, VEC_SIZE, a_format, use_autotuned=False):
"""Build the callable for a given variant (1cta, 2cta, or cublas)."""
if variant == "2cta":
if use_autotuned:
return lambda: mma_scaled_matmul(A, B, A_scale, B_scale, VEC_SIZE, num_ctas=2)
return lambda: mma_scaled_warp_specialized(A, B, A_scale, B_scale, VEC_SIZE, **BEST_2CTA_CONFIG)
elif variant == "1cta":
if use_autotuned:
return lambda: mma_scaled_matmul(A, B, A_scale, B_scale, VEC_SIZE, num_ctas=1)
return lambda: mma_scaled_warp_specialized(A, B, A_scale, B_scale, VEC_SIZE, **BEST_1CTA_CONFIG)
elif variant == "cublas":
A_scale_flat = A_scale.contiguous().flatten()
B_scale_flat = B_scale.contiguous().flatten()
def cublas_fn():
return cublas_block_scaled_matmul(A, B, A_scale_flat, B_scale_flat, a_format)
return cublas_fn
else:
raise ValueError(f"Unknown variant: {variant}")
def make_tensors(MNK, a_format, b_format):
"""Allocate and prepare input tensors for a given size and format."""
M = N = K = MNK
torch.manual_seed(0)
A, A_scale, _ = random_quantized_tensor(M, K, a_format)
B, B_scale, _ = random_quantized_tensor(N, K, b_format)
A_scale = swizzle_scales_packed_block(A_scale)
B_scale = swizzle_scales_packed_block(B_scale)
VEC_SIZE = 16 if a_format == "nvfp4" else 32
return A, B, A_scale, B_scale, VEC_SIZE
def get_variants(a_format, b_format):
"""Return the list of variants available for a given format pair."""
has_cublas = a_format == b_format and a_format in CUBLAS_FORMATS
return ["1cta", "2cta", "cublas"] if has_cublas else ["1cta", "2cta"]
def print_table(label, variants, mnk_vals, results):
"""Print a formatted benchmark table with optional ratio columns."""
has_cublas = "cublas" in variants
col_w = 16
header = f"{'MNK':>8}"
header += f" {'1cta (TFLOPS)':>{col_w}}"
header += f" {'2cta (TFLOPS)':>{col_w}}"
header += f" {'2cta/1cta':>{col_w}}"
if has_cublas:
header += f" {'cublas (TFLOPS)':>{col_w}}"
header += f" {'2cta/cublas':>{col_w}}"
print(f"block-scale-matmul-{label}:")
print(header)
for MNK in mnk_vals:
t1 = results.get((label, "1cta", MNK))
t2 = results.get((label, "2cta", MNK))
ratio_2v1 = t2 / t1 if t1 and t2 else 0.0
row = f"{MNK:>8}"
row += f" {t1:>{col_w}.1f}" if t1 else f" {'--':>{col_w}}"
row += f" {t2:>{col_w}.1f}" if t2 else f" {'--':>{col_w}}"
row += f" {ratio_2v1:>{col_w}.2f}"
if has_cublas:
tc = results.get((label, "cublas", MNK))
ratio_2vc = t2 / tc if t2 and tc else 0.0
row += f" {tc:>{col_w}.1f}" if tc else f" {'--':>{col_w}}"
row += f" {ratio_2vc:>{col_w}.2f}"
print(row)
print()
def format_config(cfg):
"""Format an autotuner Config as a concise string."""
if cfg is None:
return "(none)"
kw = cfg.kwargs
parts = [
f"BM={kw['BLOCK_M']}", f"BN={kw['BLOCK_N']}", f"BK={kw['BLOCK_K']}", f"epilogue_N={kw['EPILOGUE_BLOCK_N']}",
f"bufs={kw['num_buffers']}", f"acc_bufs={kw['num_acc_buffers']}", f"minor={kw['GRID_MINOR_DIM']}",
f"tile_w={kw['GRID_TILE_WIDTH']}", f"cga={kw['CGA_LAYOUT']}"
]
return ", ".join(parts)
def run_benchmark(use_autotuned=False):
results = {}
best_configs = {}
for a_format, b_format in ALL_FORMATS:
label = f"{a_format}-{b_format}"
variants = get_variants(a_format, b_format)
for MNK in MNK_VALS:
A, B, A_scale, B_scale, VEC_SIZE = make_tensors(MNK, a_format, b_format)
for variant in variants:
if use_autotuned:
print(f" {label} {variant} MNK={MNK}: ...", end="", flush=True)
fn = make_fn(variant, A, B, A_scale, B_scale, VEC_SIZE, a_format, use_autotuned=use_autotuned)
ms = triton.testing.do_bench(fn)
tflops = 2.0 * MNK**3 * 1e-12 / (ms * 1e-3)
results[(label, variant, MNK)] = tflops
if use_autotuned:
print(f"\r {label} {variant} MNK={MNK}: {tflops:.1f} TFLOPS")
if variant == "1cta":
best_configs[(label, "1cta", MNK)] = mma_scaled_1cta_kernel.best_config
elif variant == "2cta":
best_configs[(label, "2cta", MNK)] = mma_scaled_2cta_kernel.best_config
if use_autotuned:
largest_mnk = MNK_VALS[-1]
print(f"\nBest autotuned configs (MNK={largest_mnk}):")
for a_format, b_format in ALL_FORMATS:
label = f"{a_format}-{b_format}"
c1 = best_configs.get((label, "1cta", largest_mnk))
c2 = best_configs.get((label, "2cta", largest_mnk))
print(f" {label}:")
print(f" 1cta: {format_config(c1)}")
print(f" 2cta: {format_config(c2)}")
print()
for a_format, b_format in ALL_FORMATS:
label = f"{a_format}-{b_format}"
variants = get_variants(a_format, b_format)
print_table(label, variants, MNK_VALS, results)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Block-scaled matmul benchmark")
parser.add_argument(
"--use-autotuned",
action="store_true",
help="Use autotuned mma_scaled_matmul() instead of mma_scaled_warp_specialized().",
)
args = parser.parse_args()
run_benchmark(use_autotuned=args.use_autotuned)