Multi-CTA Matmul
This example can be found at python/examples/gluon/03-matmul-multicta.py.
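The kernel is a persistent FP16 matmul for Blackwell written in Gluon: a CLC (cluster launch control) partition hands out output tiles, a TMA load partition stages A and B in shared memory, a tcgen05 MMA partition accumulates in tensor memory, and an epilogue partition writes results back with TMA stores. Setting CGA_LAYOUT lets 2 or 4 CTAs in a cluster cooperate on a single output tile. A minimal usage sketch follows; the shapes and the pinned configuration are illustrative, and a Blackwell GPU with fp16 inputs is required:
```
# assuming torch is available and the helpers defined in this file are importable
a = torch.randn((4096, 4096), device="cuda", dtype=torch.float16)
b = torch.randn((4096, 8192), device="cuda", dtype=torch.float16)

c = matmul(a, b)  # autotuned entry point

# or pin one configuration explicitly (values taken from the benchmark config below):
c = matmul_with_config(
    a, b,
    block_size_m=128, block_size_n=256, block_size_k=64,
    grid_minor_dim=0, grid_tile_width=16,
    stages=6, acc_stages=2,
    cga_layout=((1, 0), ),
    epilogue_size_n=32, subtile_stages=4,
)
```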
import argparse
import pytest
import torch
import triton
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import (
TensorMemoryLayout,
allocate_tensor_memory,
clc,
tcgen05_commit,
tcgen05_mma,
tcgen05_mma_barrier_count,
tensor_memory_descriptor,
)
from triton.experimental.gluon.language.nvidia.hopper import mbarrier, tma
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
def is_blackwell():
if not torch.cuda.is_available():
return False
target = triton.runtime.driver.active.get_current_target()
return target.backend == "cuda" and torch.cuda.get_device_capability()[0] == 10
def as_gl_dtype(torch_dtype):
if torch_dtype == torch.float16:
return gl.float16
if torch_dtype == torch.bfloat16:
return gl.bfloat16
if torch_dtype == torch.float32:
return gl.float32
raise ValueError(f"Unsupported dtype for Gluon layout: {torch_dtype}")
@gluon.constexpr_function
def get_split_dim(cga_layout, dim):
return 1 << sum(b[dim] != 0 for b in cga_layout)
def get_epilogue_size_n(block_m, block_n, cga_layout):
"""
We can't split a layout along one of the first 4 warps
or along a CTA, as each of these has its own address space.
It would be possible to split if we used a smaller BLOCK_N
for both the TMA and MMA instructions, but this is NYI.
"""
# We can't split the layout along N: on M=64 2-CTA layouts
# the basis (0, TileN) is owned by the second warp basis!
if block_m == 64 and cga_layout:
return block_n
# We can't split the layout along N as the last basis along N
# is owned by a different CTA!
if get_split_dim(cga_layout, 1) > 1:
return block_n
return 32
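# Worked example (illustrative values): for cga_layout=((1, 0), (2, 0)) both bases point
# along M, so get_split_dim(layout, 0) == 4 and get_split_dim(layout, 1) == 1, and the
# epilogue can still be subtiled down to EPILOGUE_SIZE_N == 32. Any layout that splits N
# across CTAs, or BLOCK_M == 64 with a multi-CTA layout, keeps the full BLOCK_N instead.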
def matmul_get_configs(pre_hook=None):
return [
triton.Config(
{
"BLOCK_SIZE_M": BM,
"BLOCK_SIZE_N": BN,
"BLOCK_SIZE_K": BK,
"GRID_MINOR_DIM": minor_dim,
"GRID_TILE_WIDTH": grid_tile_width,
"STAGES": stages,
"ACC_STAGES": acc_stages,
"EPILOGUE_SIZE_N": get_epilogue_size_n(BM, BN, cga_layout),
"SUBTILE_STAGES": subtile_stages,
"CGA_LAYOUT": cga_layout,
},
num_warps=4,
num_ctas=2**len(cga_layout),
pre_hook=pre_hook,
)
for BM in (64, 128)
for BN in (128, 256, 512)
for BK in (64, 128)
for minor_dim in (0, 1)
for grid_tile_width in (4, 8, 16)
for stages in (2, 4, 6)
for acc_stages in (2, )
for subtile_stages in (4, )
for cga_layout in ((), ((1, 0), ), ((1, 0), (2, 0)))
if BN // get_split_dim(cga_layout, 1) <= 256
# Trim some configs with too large a tile
if not (BN == 512 and len(cga_layout) == 0)
]
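# Note: num_ctas = 2**len(cga_layout), so the generated configs launch 1-, 2-, or 4-CTA
# clusters, and the BN // get_split_dim(cga_layout, 1) <= 256 filter mirrors the per-CTA
# BLOCK_SIZE_N limit that matmul_with_config enforces below.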
def matmul_tma_set_block_size_hook(nargs):
block_m = nargs["BLOCK_SIZE_M"]
block_n = nargs["BLOCK_SIZE_N"]
block_k = nargs["BLOCK_SIZE_K"]
epilogue_size_n = nargs["EPILOGUE_SIZE_N"]
cga_layout = nargs["CGA_LAYOUT"]
tile_m = block_m * get_split_dim(cga_layout, 0)
nargs["a_desc"].block_shape = [tile_m, block_k]
nargs["b_desc"].block_shape = [block_k, block_n]
nargs["c_desc"].block_shape = [tile_m, epilogue_size_n]
def get_cga_layout(layout, op_idx):
assert op_idx in (0, 1)
if not layout:
return layout
# 2CTA performs an outer product so bases are [1, 0] and [0, 1]
assert layout[0] == (1, 0)
first = (1, 0) if op_idx == 0 else (0, 1)
# Broadcast along K (the reduction dimension)
# We multiply by 2 for op_idx == 1, as we have added the (0, 1) basis.
def broadcast(b):
return (b[0], 0) if op_idx == 0 else (0, 2 * b[1])
return (first, *map(broadcast, layout[1:]))
cga_layout_a = get_cga_layout(cga_layout, 0)
cga_layout_b = get_cga_layout(cga_layout, 1)
cga_layout_c = cga_layout
for desc, cga_layout in zip(("a_desc", "b_desc", "c_desc"), (cga_layout_a, cga_layout_b, cga_layout_c)):
nargs[desc].layout = gl.NVMMASharedLayout.get_default_for(
nargs[desc].block_shape,
as_gl_dtype(nargs[desc].base.dtype),
cga_layout=cga_layout,
)
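# Illustrative effect of the hook above: with BLOCK_SIZE_M=128, BLOCK_SIZE_K=64,
# BLOCK_SIZE_N=256, EPILOGUE_SIZE_N=32 and cga_layout=((1, 0),), the descriptor block
# shapes become a: [256, 64], b: [64, 256], c: [256, 32], and each descriptor's
# NVMMASharedLayout is rebuilt with the CGA layout appropriate for that operand
# (A split along M, B along N, C matching the output layout).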
# From Pallas / CUTLASS
@gluon.jit
def _planar_snake(lin_idx, m_tiles, n_tiles, minor_dim: gl.constexpr, tile_width: gl.constexpr):
major_size = n_tiles if minor_dim == 0 else m_tiles
minor_size = m_tiles if minor_dim == 0 else n_tiles
full_minor_tiles = minor_size // tile_width
full_minor_size = full_minor_tiles * tile_width
full_elements = full_minor_tiles * tile_width * major_size
minor_tile_idx = lin_idx // (tile_width * major_size)
full_minor_within = lin_idx % tile_width
full_major_within = (lin_idx // tile_width) % major_size
full_minor = minor_tile_idx * tile_width + full_minor_within
full_major = gl.where((minor_tile_idx % 2) == 0, full_major_within, major_size - 1 - full_major_within)
partial_width = minor_size - full_minor_size
partial_width = gl.where(partial_width > 0, partial_width, 1)
partial_lin = lin_idx - full_elements
partial_minor_within = partial_lin % partial_width
partial_major_within = (partial_lin // partial_width) % major_size
partial_minor = minor_tile_idx * tile_width + partial_minor_within
partial_major = gl.where((minor_tile_idx % 2) == 0, partial_major_within, major_size - 1 - partial_major_within)
in_full_tile = lin_idx < full_elements
minor = gl.where(in_full_tile, full_minor, partial_minor)
major = gl.where(in_full_tile, full_major, partial_major)
if minor_dim == 0:
return minor, major
return major, minor
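# _planar_snake maps a linear tile index to (pid_m, pid_n) by walking the grid in bands of
# GRID_TILE_WIDTH tiles along the minor dimension, reversing the traversal direction of the
# major dimension on alternate bands ("snaking") so that consecutive tiles stay close and
# L2 reuse improves. The trailing partial band (when the minor extent is not a multiple of
# the band width) is handled separately.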
@gluon.aggregate
class Counter:
index: gl.tensor
phase: gl.tensor
num_barriers: gl.constexpr
@gluon.jit
def create(phase, num_barriers: gl.constexpr):
return Counter(gl.to_tensor(0), gl.to_tensor(phase), num_barriers)
@gluon.must_use_result
@gluon.jit
def next(self, pred=True):
incr = self.index + gl.where(pred, 1, 0)
rollover = incr == self.num_barriers
index = gl.where(rollover, 0, incr)
phase = gl.where(rollover, self.phase ^ 1, self.phase)
return Counter(index, phase, self.num_barriers)
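# Counter is a small circular counter tracking an mbarrier slot index plus its parity
# ("phase") bit. Example sequence for Counter.create(0, 2):
#   (index, phase) = (0, 0) -> (1, 0) -> (0, 1) -> (1, 1) -> (0, 0) -> ...
# The phase flips each time the index wraps, matching the phase argument passed to
# mbarrier.wait throughout the kernel.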
@gluon.aggregate
class ClcTileSchedulerConsumer:
has_work: gl.tensor
tile_id: gl.tensor
pid_m: gl.tensor
pid_n: gl.tensor
num_pid_m: gl.tensor
num_pid_n: gl.tensor
TILE_M: gl.constexpr
TILE_N: gl.constexpr
MINOR_DIM: gl.constexpr
GRID_TILE_WIDTH: gl.constexpr
clc_result_buffers: gl.shared_memory_descriptor
clc_barriers: gl.shared_memory_descriptor
clc_planar_pid_buffers: gl.shared_memory_descriptor
clc_planar_ready_bars: gl.shared_memory_descriptor
clc_consumed_bars: gl.shared_memory_descriptor
counter: Counter
consumed_counter: Counter
@gluon.jit
def initialize(M, N, TILE_M: gl.constexpr, TILE_N: gl.constexpr, MINOR_DIM: gl.constexpr,
GRID_TILE_WIDTH: gl.constexpr, clc_result_buffers, clc_barriers, clc_planar_pid_buffers,
clc_planar_ready_bars, clc_consumed_bars):
tile_id = gl.program_id(axis=0)
num_pid_m = gl.cdiv(M, TILE_M)
num_pid_n = gl.cdiv(N, TILE_N)
pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, MINOR_DIM, GRID_TILE_WIDTH)
has_work = gl.to_tensor(True)
counter = Counter.create(0, clc_barriers.shape[0])
consumed_counter = Counter.create(0, clc_barriers.shape[0])
return ClcTileSchedulerConsumer(
has_work,
tile_id,
pid_m,
pid_n,
num_pid_m,
num_pid_n,
TILE_M,
TILE_N,
MINOR_DIM,
GRID_TILE_WIDTH,
clc_result_buffers,
clc_barriers,
clc_planar_pid_buffers,
clc_planar_ready_bars,
clc_consumed_bars,
counter,
consumed_counter,
)
@gluon.jit
def get_offsets(self):
return self.pid_m * self.TILE_M, self.pid_n * self.TILE_N
@gluon.jit
def step(self, iteration):
# The 0-th iteration uses the program_id as the tile_id.
# At the end of each iteration we prefetch the next tile.
# As such we must signal the consumed slot at the end of
# each iteration skipping the first one.
consumed_counter = self.consumed_counter
if iteration > 0:
mbarrier.arrive(self.clc_consumed_bars.index(consumed_counter.index))
consumed_counter = consumed_counter.next()
counter = self.counter
barrier = self.clc_barriers.index(counter.index)
result = self.clc_result_buffers.index(counter.index)
mbarrier.wait(barrier, counter.phase)
clc_res = clc.load_result(result)
mbarrier.wait(self.clc_planar_ready_bars.index(counter.index), counter.phase)
planar_slot = self.clc_planar_pid_buffers.index(counter.index)
planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
[[0]] * (gl.num_ctas().bit_length() - 1))
packed_pid = planar_slot.load(planar_layout).reshape([])
pid_m = ((packed_pid >> 32) & 0xFFFFFFFF).to(gl.int32)
pid_n = (packed_pid & 0xFFFFFFFF).to(gl.int32)
has_work = clc_res.is_canceled()
tile_id = self.tile_id
if has_work:
tile_id = clc_res.program_id(0)
return ClcTileSchedulerConsumer(
has_work,
tile_id,
pid_m,
pid_n,
self.num_pid_m,
self.num_pid_n,
self.TILE_M,
self.TILE_N,
self.MINOR_DIM,
self.GRID_TILE_WIDTH,
self.clc_result_buffers,
self.clc_barriers,
self.clc_planar_pid_buffers,
self.clc_planar_ready_bars,
self.clc_consumed_bars,
counter.next(),
consumed_counter,
)
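# PartitionArgs (below) bundles everything the four warp-specialized partitions share:
# the TMA tensor descriptors, the shared-memory and tensor-memory staging buffers, and
# the mbarriers that sequence the CLC -> load -> MMA -> epilogue pipeline.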
@gluon.aggregate
class PartitionArgs:
a_desc: tma.tensor_descriptor
b_desc: tma.tensor_descriptor
c_desc: tma.tensor_descriptor
a_bufs: gl.shared_memory_descriptor
b_bufs: gl.shared_memory_descriptor
load_empty_bars: gl.shared_memory_descriptor
load_ready_bars: gl.shared_memory_descriptor
acc_bufs: tensor_memory_descriptor
acc_empty_bars: gl.shared_memory_descriptor
acc_ready_bars: gl.shared_memory_descriptor
clc_result_buffers: gl.shared_memory_descriptor
clc_barriers: gl.shared_memory_descriptor
clc_planar_pid_buffers: gl.shared_memory_descriptor
clc_planar_ready_bars: gl.shared_memory_descriptor
clc_consumed_bars: gl.shared_memory_descriptor
MINOR_DIM: gl.constexpr
GRID_TILE_WIDTH: gl.constexpr
SUBTILE_STAGES: gl.constexpr
@gluon.jit
def get_clc_consumer(self):
return ClcTileSchedulerConsumer.initialize(
self.c_desc.shape[0],
self.c_desc.shape[1],
self.a_desc.block_shape[0],
self.b_desc.block_shape[1],
self.MINOR_DIM,
self.GRID_TILE_WIDTH,
self.clc_result_buffers,
self.clc_barriers,
self.clc_planar_pid_buffers,
self.clc_planar_ready_bars,
self.clc_consumed_bars,
)
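# CLC producer partition: repeatedly issues clc.try_cancel to claim the next canceled tile,
# converts its linear tile id into (pid_m, pid_n) via the planar-snake order, and publishes
# the pair packed into a single int64 slot so the consumer partitions never recompute the
# mapping. A slot is reused only after the other three partitions have arrived on its
# clc_consumed barrier.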
@gluon.jit
def matmul_clc_partition(p):
TILE_M: gl.constexpr = p.a_desc.block_shape[0]
TILE_N: gl.constexpr = p.b_desc.block_shape[1]
has_work = gl.to_tensor(True)
num_pid_m = gl.cdiv(p.c_desc.shape[0], TILE_M)
num_pid_n = gl.cdiv(p.c_desc.shape[1], TILE_N)
state = Counter.create(0, p.clc_barriers.shape[0])
consumed_state = Counter.create(1, p.clc_barriers.shape[0])
ACC_STAGES: gl.constexpr = p.clc_barriers.shape[0]
i = 0
while has_work:
# Reuse the slot only after all consumer partitions signaled consumed.
mbarrier.wait(p.clc_consumed_bars.index(consumed_state.index), consumed_state.phase, pred=(i >= ACC_STAGES))
barrier = p.clc_barriers.index(state.index)
result = p.clc_result_buffers.index(state.index)
# 16 bytes: clc.try_cancel writes a `.b128` result
mbarrier.expect(barrier, 16)
clc.try_cancel(result, barrier)
mbarrier.wait(barrier, state.phase)
clc_res = clc.load_result(result)
has_work = clc_res.is_canceled()
pid_m = gl.to_tensor(0)
pid_n = gl.to_tensor(0)
if has_work:
tile_id = clc_res.program_id(0)
pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, p.MINOR_DIM, p.GRID_TILE_WIDTH)
packed_pid = (pid_m.to(gl.int64) << 32) | (pid_n.to(gl.int64) & 0xFFFFFFFF)
planar_slot = p.clc_planar_pid_buffers.index(state.index)
planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
[[0]] * (gl.num_ctas().bit_length() - 1))
planar_slot.store(gl.full([1], packed_pid, gl.int64, layout=planar_layout))
mbarrier.arrive(p.clc_planar_ready_bars.index(state.index))
state = state.next()
consumed_state = consumed_state.next()
i += 1
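# Load partition: for each claimed tile, streams the A and B operands along K with
# multicast TMA loads into a STAGES-deep shared-memory ring, waiting on load_empty_bars
# before overwriting a slot and signalling load_ready_bars (via mbarrier.expect plus TMA
# completion) once both operands for that slot have landed.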
@gluon.jit
def matmul_load_partition(p):
BLOCK_K: gl.constexpr = p.a_desc.block_shape[1]
K = p.a_desc.shape[1]
concurrent_loads: gl.constexpr = p.load_ready_bars.shape[0]
state = Counter.create(1, concurrent_loads)
scheduler = p.get_clc_consumer()
i = 0
while scheduler.has_work:
off_m, off_n = scheduler.get_offsets()
for k in range(0, K, BLOCK_K):
pred = (i > 0) or (k >= BLOCK_K * concurrent_loads)
mbarrier.wait(p.load_empty_bars.index(state.index), state.phase, pred=pred)
bar = p.load_ready_bars.index(state.index)
mbarrier.expect(bar, p.a_desc.nbytes_per_cta + p.b_desc.nbytes_per_cta)
tma.async_load(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True)
tma.async_load(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True)
state = state.next()
scheduler = scheduler.step(i)
i += 1
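# MMA partition: accumulates one output tile in tensor memory by chaining tcgen05_mma over
# the K loop, consuming the shared-memory slots filled by the load partition and releasing
# them through load_empty_bars. Once the K loop finishes, tcgen05_commit arrives on
# acc_ready_bars to hand the accumulator stage to the epilogue.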
@gluon.jit
def matmul_mma_partition(p):
BLOCK_K: gl.constexpr = p.a_desc.block_shape[1]
K = p.a_desc.shape[1]
ACC_STAGES: gl.constexpr = p.acc_empty_bars.shape[0]
load_state = Counter.create(0, p.load_empty_bars.shape[0])
acc_state = Counter.create(1, ACC_STAGES)
scheduler = p.get_clc_consumer()
i = 0
while scheduler.has_work:
acc_buf = p.acc_bufs.index(acc_state.index)
mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase, pred=(i >= ACC_STAGES))
use_acc = False
for k in range(0, K, BLOCK_K):
mbarrier.wait(p.load_ready_bars.index(load_state.index), load_state.phase)
tcgen05_mma(p.a_bufs.index(load_state.index), p.b_bufs.index(load_state.index), acc_buf, use_acc=use_acc,
multicast=True, mbarriers=[p.load_empty_bars.index(load_state.index)])
load_state = load_state.next()
use_acc = True
tcgen05_commit(p.acc_ready_bars.index(acc_state.index), descs=[p.a_bufs.index(0), p.b_bufs.index(0)])
acc_state = acc_state.next()
scheduler = scheduler.step(i)
i += 1
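# Epilogue partition: waits for a finished accumulator stage, splits it into
# SPLIT_TILE_N-wide subtiles, round-trips each subtile through a SUBTILE_STAGES-deep
# shared-memory ring and writes it out with TMA stores, then arrives on acc_empty_bars so
# the MMA partition can reuse the accumulator stage.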
@gluon.jit
def matmul_epilogue_partition(p):
TILE_M: gl.constexpr = p.a_desc.block_shape[0]
TILE_N: gl.constexpr = p.b_desc.block_shape[1]
SPLIT_TILE_N: gl.constexpr = p.c_desc.block_shape[1]
# Separate knobs: SUBTILE_STAGES controls shared-memory usage,
# and SUBTILE_FACTOR is the maximum number of subtiles into which we can split the tile,
# which might be too large to fit within shared-memory limits.
SUBTILE_FACTOR: gl.constexpr = TILE_N // SPLIT_TILE_N
SUBTILE_STAGES: gl.constexpr = p.SUBTILE_STAGES
ACC_STAGES: gl.constexpr = p.acc_empty_bars.shape[0]
dtype: gl.constexpr = p.c_desc.dtype
acc_state = Counter.create(0, ACC_STAGES)
acc_smems = gl.allocate_shared_memory(dtype, [SUBTILE_STAGES, TILE_M, SPLIT_TILE_N], p.c_desc.layout)
sub_acc_state = Counter.create(0, SUBTILE_STAGES)
scheduler = p.get_clc_consumer()
i = 0
while scheduler.has_work:
off_m, off_n = scheduler.get_offsets()
mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
acc_buf = p.acc_bufs.index(acc_state.index)
for s in gl.static_range(SUBTILE_FACTOR):
acc_sub = acc_buf.slice(SPLIT_TILE_N * s, SPLIT_TILE_N)
acc_smem = acc_smems.index(sub_acc_state.index)
acc = acc_sub.load().to(dtype)
tma.store_wait(pendings=SUBTILE_STAGES - 1)
acc_smem.store(acc)
tma.async_copy_shared_to_global(p.c_desc, [off_m, off_n + SPLIT_TILE_N * s], acc_smem)
sub_acc_state = sub_acc_state.next()
# Signal that the accumulator slot can be reused only after all stores are done.
mbarrier.arrive(p.acc_empty_bars.index(acc_state.index))
acc_state = acc_state.next()
scheduler = scheduler.step(i)
i += 1
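# Kernel entry point: allocates the shared-memory operand rings, the tensor-memory
# accumulator stages, and every mbarrier used by the pipeline, then hands the bundled
# PartitionArgs to gl.warp_specialize, which runs the four partitions concurrently within
# the CTA.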
@gluon.jit
def _matmul_kernel(
a_desc,
b_desc,
c_desc,
M,
N,
K,
BLOCK_SIZE_M: gl.constexpr,
BLOCK_SIZE_N: gl.constexpr,
BLOCK_SIZE_K: gl.constexpr,
GRID_MINOR_DIM: gl.constexpr,
GRID_TILE_WIDTH: gl.constexpr,
STAGES: gl.constexpr,
ACC_STAGES: gl.constexpr,
CGA_LAYOUT: gl.constexpr,
EPILOGUE_SIZE_N: gl.constexpr,
SUBTILE_STAGES: gl.constexpr,
):
BLOCK_M: gl.constexpr = a_desc.block_shape[0]
BLOCK_N: gl.constexpr = b_desc.block_shape[1]
TWO_CTAS: gl.constexpr = gl.num_ctas() > 1
N_PARTITIONS: gl.constexpr = 4
dtype: gl.constexpr = a_desc.dtype
a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + a_desc.block_shape, a_desc.layout)
b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout)
# Number of CTAs that will arrive on the barrier from a tcgen05_commit after an MMA instruction
mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True)
# Equiv. consumed_barrier. Barrier TCGEN05 MMA -> Load TMA
load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES)
# Equiv. ab_tma_barrier. Barrier Load TMA -> TCGEN05 MMA
load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=TWO_CTAS)
for i in gl.static_range(STAGES):
mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count)
mbarrier.init(load_ready_bars.index(i), count=1)
tmem_layout: gl.constexpr = TensorMemoryLayout(
[BLOCK_SIZE_M, BLOCK_N // get_split_dim(CGA_LAYOUT, 1)],
col_stride=1,
cga_layout=CGA_LAYOUT,
two_ctas=TWO_CTAS,
)
acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, BLOCK_M, BLOCK_N], tmem_layout)
# Equiv. store_done_barrier. Barrier Store TMA -> TCGEN05 MMA
acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=TWO_CTAS)
# Equiv. mma_done_barrier. Barrier TCGEN05 MMA -> Store TMA
acc_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
for i in gl.static_range(ACC_STAGES):
mbarrier.init(acc_empty_bars.index(i), count=1)
mbarrier.init(acc_ready_bars.index(i), count=mma_barrier_count)
clc_barriers = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
clc_planar_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
clc_consumed_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=TWO_CTAS)
for i in gl.static_range(ACC_STAGES):
mbarrier.init(clc_barriers.index(i), count=1)
mbarrier.init(clc_planar_ready_bars.index(i), count=1)
# Every consumer partition (all but the CLC producer) arrives on this barrier
mbarrier.init(clc_consumed_bars.index(i), count=N_PARTITIONS - 1)
cga_layout: gl.constexpr = [[0]] * (gl.num_ctas().bit_length() - 1)
clc_layout: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, [0], cga_layout=cga_layout)
clc_result_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 2], clc_layout)
clc_planar_pid_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 1], clc_layout)
p = PartitionArgs(
a_desc,
b_desc,
c_desc,
a_bufs,
b_bufs,
load_empty_bars,
load_ready_bars,
acc_bufs,
acc_empty_bars,
acc_ready_bars,
clc_result_buffers,
clc_barriers,
clc_planar_pid_buffers,
clc_planar_ready_bars,
clc_consumed_bars,
GRID_MINOR_DIM,
GRID_TILE_WIDTH,
SUBTILE_STAGES,
)
gl.warp_specialize([
(matmul_epilogue_partition, (p, )),
(matmul_load_partition, (p, )),
(matmul_mma_partition, (p, )),
(matmul_clc_partition, (p, )),
], [1, 1, 1], [24, 24, 24])
matmul_kernel = triton.autotune(
configs=matmul_get_configs(pre_hook=matmul_tma_set_block_size_hook),
key=["M", "N", "K"],
)(_matmul_kernel)
def matmul_with_config(
a,
b,
out=None,
*,
block_size_m,
block_size_n,
block_size_k,
grid_minor_dim,
grid_tile_width,
stages,
acc_stages,
cga_layout,
epilogue_size_n,
subtile_stages,
):
if block_size_n // get_split_dim(cga_layout, 1) > 256:
raise ValueError(
f"cga_layout={list(cga_layout)} only supports BLOCK_SIZE_N <= {256 * get_split_dim(cga_layout, 1)}")
M, K = a.shape
K1, N = b.shape
if K != K1:
raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
if a.dtype != torch.float16 or b.dtype != torch.float16:
raise ValueError("matmul only supports fp16 inputs")
if out is None:
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
else:
if out.shape != (M, N):
raise ValueError(f"Output has invalid shape {out.shape}, expected {(M, N)}")
if out.device != a.device or out.dtype != a.dtype:
raise ValueError("Output must match input device and dtype")
c = out
dummy_block = [1, 1]
dummy_layout = gl.NVMMASharedLayout.get_default_for(dummy_block, gl.float16)
a_desc = TensorDescriptor.from_tensor(a, dummy_block, dummy_layout)
b_desc = TensorDescriptor.from_tensor(b, dummy_block, dummy_layout)
c_desc = TensorDescriptor.from_tensor(c, dummy_block, dummy_layout)
matmul_tma_set_block_size_hook({
"a_desc": a_desc,
"b_desc": b_desc,
"c_desc": c_desc,
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GRID_MINOR_DIM": grid_minor_dim,
"GRID_TILE_WIDTH": grid_tile_width,
"STAGES": stages,
"ACC_STAGES": acc_stages,
"CGA_LAYOUT": cga_layout,
"EPILOGUE_SIZE_N": epilogue_size_n,
})
def grid(meta):
tile_m = meta["BLOCK_SIZE_M"] * (2 if bool(meta["CGA_LAYOUT"]) else 1)
tile_n = meta["BLOCK_SIZE_N"]
num_tiles = triton.cdiv(M, tile_m) * triton.cdiv(N, tile_n)
return (num_tiles, )
_matmul_kernel[grid](
a_desc,
b_desc,
c_desc,
M,
N,
K,
block_size_m,
block_size_n,
block_size_k,
grid_minor_dim,
grid_tile_width,
stages,
acc_stages,
cga_layout,
epilogue_size_n,
subtile_stages,
num_warps=4,
num_ctas=2**len(cga_layout),
)
return c
def matmul(a, b):
M, K = a.shape
K1, N = b.shape
if K != K1:
raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
if a.dtype != torch.float16 or b.dtype != torch.float16:
raise ValueError("matmul only supports fp16 inputs")
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
dummy_block = [1, 1]
dummy_layout = gl.NVMMASharedLayout.get_default_for(dummy_block, gl.float16)
a_desc = TensorDescriptor.from_tensor(a, dummy_block, dummy_layout)
b_desc = TensorDescriptor.from_tensor(b, dummy_block, dummy_layout)
c_desc = TensorDescriptor.from_tensor(c, dummy_block, dummy_layout)
def grid(meta):
tile_m = meta["BLOCK_SIZE_M"] * (2 if bool(meta["CGA_LAYOUT"]) else 1)
tile_n = meta["BLOCK_SIZE_N"]
num_tiles = triton.cdiv(M, tile_m) * triton.cdiv(N, tile_n)
return (num_tiles, )
matmul_kernel[grid](a_desc, b_desc, c_desc, M, N, K)
return c
# Subset of matmul_get_configs
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
@pytest.mark.parametrize("BLOCK_SIZE_M", [64, 128])
@pytest.mark.parametrize("BLOCK_SIZE_N", [128, 256])
@pytest.mark.parametrize("BLOCK_SIZE_K", [64, 128])
@pytest.mark.parametrize("GRID_MINOR_DIM", [0, 1])
@pytest.mark.parametrize("GRID_TILE_WIDTH", [8])
@pytest.mark.parametrize("CGA_LAYOUT", [(), ((1, 0), ), ((1, 0), (2, 0))])
@pytest.mark.parametrize("STAGES", [2, 4])
@pytest.mark.parametrize("ACC_STAGES", [2])
@pytest.mark.parametrize("EPILOGUE_SIZE_N", [32])
@pytest.mark.parametrize("SUBTILE_STAGES", [4])
@pytest.mark.parametrize("M, N, K", [(100, 200, 200)])
def test_matmul_matches_torch(
M,
N,
K,
BLOCK_SIZE_M,
BLOCK_SIZE_N,
BLOCK_SIZE_K,
GRID_MINOR_DIM,
GRID_TILE_WIDTH,
CGA_LAYOUT,
STAGES,
ACC_STAGES,
EPILOGUE_SIZE_N,
SUBTILE_STAGES,
):
# Epilogue splitting must stay within a CTA, so recompute EPILOGUE_SIZE_N from the block
# shape and CGA layout (overriding the parametrized value where splitting isn't possible).
EPILOGUE_SIZE_N = get_epilogue_size_n(BLOCK_SIZE_M, BLOCK_SIZE_N, CGA_LAYOUT)
torch.manual_seed(0)
a = torch.rand((M, K), device=torch.device("cuda"), dtype=torch.float16)
b = torch.rand((K, N), device=torch.device("cuda"), dtype=torch.float16)
expected = torch.matmul(a, b)
try:
actual = matmul_with_config(
a,
b,
block_size_m=BLOCK_SIZE_M,
block_size_n=BLOCK_SIZE_N,
block_size_k=BLOCK_SIZE_K,
grid_minor_dim=GRID_MINOR_DIM,
grid_tile_width=GRID_TILE_WIDTH,
stages=STAGES,
acc_stages=ACC_STAGES,
cga_layout=CGA_LAYOUT,
epilogue_size_n=EPILOGUE_SIZE_N,
subtile_stages=SUBTILE_STAGES,
)
except triton.OutOfResources:
pytest.skip("Out of resources")
torch.testing.assert_close(expected, actual, atol=1e-1, rtol=1e-2)
########################################################
# Benchmarking
########################################################
def show_profile(profile_name):
import triton.profiler.viewer as proton_viewer
metric_names = ["tflop16/s", "time/ms"]
file_name = f"{profile_name}.hatchet"
tree, metrics = proton_viewer.parse(metric_names, file_name)
proton_viewer.print_tree(tree, metrics)
def print_benchmark_header():
print("=" * 60)
print("Gluon Matmul Benchmark")
print("=" * 60)
props = torch.cuda.get_device_properties(0)
print(f"Device: {props.name}, SMs: {props.multi_processor_count}")
def create_benchmark_tensors():
M, N, K = 4096, 8192, 4096
print(f"Matrix: M={M}, N={N}, K={K}")
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)
c_triton = torch.empty((M, N), device="cuda", dtype=torch.float16)
c_torch = torch.empty((M, N), device="cuda", dtype=torch.float16)
expected = torch.matmul(a, b)
return (M, N, K), a, b, c_triton, c_torch, expected
def get_benchmark_kernel_config():
return {
"tile_m": 128,
"tile_n": 256,
"tile_k": 64,
"grid_minor_dim": 0,
"grid_tile_width": 16,
"stages": 6,
"acc_stages": 2,
"cga_layout": ((1, 0), ),
"epilogue_tile_n": 32,
"subtile_stages": 4,
}
def make_gluon_runner(a, b, c_triton, cfg, use_autotuned=False):
if use_autotuned:
def run_gluon():
return matmul(a, b)
return run_gluon
def run_gluon():
return matmul_with_config(
a,
b,
out=c_triton,
block_size_m=cfg["tile_m"],
block_size_n=cfg["tile_n"],
block_size_k=cfg["tile_k"],
grid_minor_dim=cfg["grid_minor_dim"],
grid_tile_width=cfg["grid_tile_width"],
stages=cfg["stages"],
acc_stages=cfg["acc_stages"],
cga_layout=cfg["cga_layout"],
epilogue_size_n=cfg["epilogue_tile_n"],
subtile_stages=cfg["subtile_stages"],
)
return run_gluon
def run_profile(shape, a, b, c_torch, run_gluon):
import triton.profiler as proton
M, N, K = shape
proton.start("matmul", hook="triton")
proton.deactivate(0)
l2_cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
def bench_fn(label, reps, fn):
print(f"Benchmarking {label}: ...", end="")
proton.deactivate()
for _ in range(5):
fn() # warmup
triton.runtime.driver.active.clear_cache(l2_cache)
try:
for _ in range(reps):
proton.deactivate()
triton.runtime.driver.active.clear_cache(l2_cache)
proton.activate()
fn()
finally:
proton.deactivate()
print(f"\rBenchmarking {label}: done")
bytes_per_elem = a.element_size()
scope_metrics = {
"bytes": bytes_per_elem * (M * K + N * K + M * N),
f"flops{bytes_per_elem * 8}": 2.0 * M * N * K,
}
def torch_profiled():
with proton.scope(f"torch [M={M}, N={N}, K={K}]", scope_metrics):
torch.matmul(a, b, out=c_torch)
def gluon_profiled():
with proton.scope(f"gluon [M={M}, N={N}, K={K}]", scope_metrics):
run_gluon()
bench_fn("torch", reps=100, fn=torch_profiled)
bench_fn("gluon", reps=100, fn=gluon_profiled)
proton.finalize()
print("Proton profile written to `matmul.hatchet`")
show_profile("matmul")
def benchmark(*, profile=True, use_autotuned=False):
if not is_blackwell():
raise RuntimeError("This benchmark requires a Blackwell CUDA GPU.")
print_benchmark_header()
shape, a, b, c_triton, c_torch, expected = create_benchmark_tensors()
kernel_cfg = get_benchmark_kernel_config()
runner_name = "matmul (autotuned)" if use_autotuned else "matmul_with_config"
print(f"Gluon runner: {runner_name}")
run_gluon = make_gluon_runner(a, b, c_triton, kernel_cfg, use_autotuned=use_autotuned)
actual = run_gluon()
torch.testing.assert_close(actual, expected, atol=1e-1, rtol=1e-2)
if use_autotuned:
print(f"Autotuned best config: {matmul_kernel.best_config}")
if not profile:
print("Skipping profiling (--no-profile).")
return
run_profile(shape, a, b, c_torch, run_gluon)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Gluon matmul benchmark")
"""
To enable NCU profiling, run the script with the following example command:
```
ncu --target-processes all \
--set full \
--import-source yes \
--kernel-name-base function \
--kernel-name 'regex:.*_matmul_kernel.*' \
--launch-count 1 \
-o ncu_triton_matmul \
python 03-matmul-multicta.py --no-profile
```
"""
parser.add_argument(
"--no-profile",
action="store_true",
help="Skip Proton profiling and exit after validation.",
)
parser.add_argument(
"--use-autotuned",
action="store_true",
help="Use autotuned matmul() instead of matmul_with_config() for the Gluon runner.",
)
args = parser.parse_args()
benchmark(profile=not args.no_profile, use_autotuned=args.use_autotuned)