Warp-Group MMA
Warp-Group MMA (also known as WGMMA or MMAv3) is a Hopper-specific instruction for performing matrix multiply-accumulate operations using the Tensor Cores. WGMMA instructions are asynchronous, meaning they can be pipelined.
In this tutorial, we will cover how to use WGMMAs in Gluon. We will build a simple matmul kernel to demonstrate practical uses of WGMMA, and show an example where WGMMAs can be pipelined for better performance.
import pytest
import torch
import triton
import itertools
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
from triton.experimental.gluon.language.nvidia.hopper import (
tma,
mbarrier,
fence_async_shared,
warpgroup_mma_init,
warpgroup_mma,
warpgroup_mma_wait,
)
def is_hopper():
target = triton.runtime.driver.active.get_current_target()
return target.backend == "cuda" and torch.cuda.get_device_capability()[0] == 9
if __name__ == "__main__" and not is_hopper():
raise RuntimeError("This tutorial requires a Hopper NVIDIA GPU")
Let’s illustrate WGMMA with a trivial kernel launched with grid size (1, ). This kernel performs MMA on a small tensor.
warpgroup_mma performs d = a * b + c. The a operand can be passed as
registers or through shared memory. The b operand must be passed through
shared memory, and the c operand must be passed through registers.
warpgroup_mma itself is composed of many smaller wgmma.mma_async PTX
instructions, which support a limited set of instruction shapes.
The instruction shape is specified as [m, n, k], where:
- k is always 256 / A.dtype.primitive_bitwidth
- m is always 16
- n can be chosen as follows:
For floating point dtypes, n must be a positive multiple of 8, up to and
including 256. WGMMA supports 8-bit integers, but n must be chosen from:
224, 208, 192, 176, 160, 144, 128, 112, 96, 80, 64, 48, 32, 24, 16, 8
n must be chosen such that it evenly divides BLOCK_N, the inner
dimension of the MMA tile, and it must be less than or equal to maxN,
computed as:
mReps = ceildiv(M, m)
nReps = ceildiv(num_warps, mReps)
maxN = max(N // nReps, 8)
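For example, the maxN rule can be evaluated in plain Python (the helper name is ours):

```python
from math import ceil

def max_instr_shape_n(M, N, num_warps, m=16):
    # Tile M by the m=16 instruction height, then spread any remaining
    # warps along N to find the widest usable instruction shape.
    mReps = ceil(M / m)
    nReps = ceil(num_warps / mReps)
    return max(N // nReps, 8)

# For a 64x128 tile with 4 warps: mReps=4, nReps=1, so maxN=128.
print(max_instr_shape_n(64, 128, num_warps=4))  # 128
# With 8 warps, the extra warp group halves maxN: nReps=2 -> 64.
print(max_instr_shape_n(64, 128, num_warps=8))  # 64
```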
warpgroup_mma divides the MMA across warps using warps_per_cta, in the
same way BlockedLayout.warps_per_cta tiles a tensor across warps. The
smallest indivisible unit of warps_per_cta is [4, 1]. Note that this
means WGMMA requires at least 4 warps, which together make up one warp group.
To choose the right warps_per_cta, start from the atom [4, 1] and
double it along either dimension until it matches the number of warps. Note
that since m=16 and there must be at least 4 warps along M, the M dimension
must be at least 64.
Note that when num_warps=8, we can choose [4, 2] or [8, 1], but recall from
02-layouts that this choice can affect the performance of, e.g., reductions.
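This doubling procedure can be sketched in plain Python (the helper name is ours; the tutorial builds the same logic into a constexpr function later):

```python
def choose_warps_per_cta(BLOCK_M, num_warps, m=16):
    # Start from the warp-group atom [4, 1] and double along a dimension
    # until all warps are used.
    warps = [4, 1]
    while warps[0] * warps[1] < num_warps:
        # Grow along M only while doing so would not cause broadcasting,
        # i.e. each warp still covers a full m=16 row tile.
        if BLOCK_M > m * warps[0]:
            warps[0] *= 2
        else:
            warps[1] *= 2
    return warps

print(choose_warps_per_cta(BLOCK_M=128, num_warps=8))  # [8, 1]
print(choose_warps_per_cta(BLOCK_M=64, num_warps=8))   # [4, 2]
```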
warpgroup_mma is an asynchronous operation whose completion is tracked by commit groups, like async copies and TMA stores. Issuing a WGMMA operation implicitly commits it to a WGMMA group, and we can wait until at most N operations remain outstanding.
Because warpgroup_mma is asynchronous, until the operation is complete we cannot access the result even though it is in registers, nor can we write to any of the shared memory inputs. WGMMA accesses shared memory through the async proxy. Since TMAs also access shared memory through the async proxy, we don’t need fences between TMA and WGMMA instructions.
b_smem.store(b)
fence_async_shared()
warpgroup_mma(a, b_smem, c, is_async=True)
A fence is needed between the shared store and warpgroup_mma to order their shared memory accesses.
Completion of the WGMMA implies its reads from shared memory are complete. Thus, it is safe to write to the shared memory inputs after waiting:
d = warpgroup_mma(a, b_smem, c, is_async=True)
d = warpgroup_mma_wait(num_outstanding=0, deps=(d, ))
b_smem.store(b)
If the LHS operand is supplied in registers via a shared load, completion of the WGMMA implies the shared load is complete, and subsequent accesses to the buffer via the async proxy do not require a fence:
a = a_smem.load(dot_operand_layout)
d = warpgroup_mma(a, b_smem, c, is_async=True)
d = warpgroup_mma_wait(num_outstanding=0, deps=(d, ))
tma.async_load(a_desc, [0, 0], bar, a_smem)
Let’s implement a simple matmul kernel that uses WGMMA.
@gluon.jit
def small_mma_kernel(a_desc, b_desc, c_desc, d_desc, #
LHS_IN_REG: gl.constexpr, INSTR_SHAPE_N: gl.constexpr, num_warps: gl.constexpr):
# Load A, B, and C tiles.
bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
mbarrier.init(bar, count=1)
# A has shape [M, K].
a_smem = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_type.shape, a_desc.layout)
# B has shape [K, N].
b_smem = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_type.shape, b_desc.layout)
# C has shape [M, N].
c_smem = gl.allocate_shared_memory(c_desc.dtype, c_desc.block_type.shape, c_desc.layout)
mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes + c_desc.block_type.nbytes)
tma.async_load(a_desc, [0, 0], bar, a_smem)
tma.async_load(b_desc, [0, 0], bar, b_smem)
tma.async_load(c_desc, [0, 0], bar, c_smem)
mbarrier.wait(bar, phase=0)
mbarrier.invalidate(bar)
# Let's parameterize the kernel over LHS_IN_REG and INSTR_SHAPE_N to see how
# it can affect performance.
m: gl.constexpr = 16
k: gl.constexpr = 256 // a_desc.dtype.primitive_bitwidth
n: gl.constexpr = INSTR_SHAPE_N
warps_per_cta: gl.constexpr = [num_warps, 1]
# The MMA shape is passed through the layout of `c`, which must always have
# an NVMMADistributedLayout.
c_layout: gl.constexpr = gl.NVMMADistributedLayout(
version=[3, 0],
warps_per_cta=warps_per_cta,
instr_shape=[m, n, k],
)
# When A is passed through registers, it must have the following layout:
a_reg_layout: gl.constexpr = gl.DotOperandLayout(
operand_index=0,
parent=c_layout,
k_width=32 // a_desc.dtype.primitive_bitwidth,
)
# When an operand is passed through shared memory, it must have an
# NVMMASharedLayout. TMA requires using an NVMMASharedLayout.
gl.static_assert(isinstance(a_smem.type.layout, gl.NVMMASharedLayout))
gl.static_assert(isinstance(b_smem.type.layout, gl.NVMMASharedLayout))
if LHS_IN_REG:
a = a_smem.load(a_reg_layout)
else:
a = a_smem
c = c_smem.load(c_layout)
# Issue the async WGMMA. Note that `is_async=False` is the default value,
# and all this does is immediately wait for 0 outstanding operations. In
# this tutorial, we will always use `is_async=True`.
#
# Another important flag to consider is `use_acc`. When `use_acc=False`, the
# `c` input is ignored and the accumulator is zero-initialized. This can be
# an efficient way to zero the accumulator.
d = warpgroup_mma(a, b_smem, c, is_async=True, use_acc=True)
# To ensure correct ordering between `warpgroup_mma`, the wait, and uses of
# the result, you must thread the `warpgroup_mma` result through the wait
# via the `deps` argument and use the return value of the
# `warpgroup_mma_wait`.
#
# Wait for 0 outstanding operations, so we know the WGMMA is complete.
d = warpgroup_mma_wait(num_outstanding=0, deps=(d, ))
d_smem = gl.allocate_shared_memory(d_desc.dtype, d_desc.block_type.shape, d_desc.layout)
d_smem.store(d)
fence_async_shared()
tma.async_copy_shared_to_global(d_desc, [0, 0], d_smem)
tma.store_wait(pendings=0)
def small_mma(A, B, C, D, INSTR_SHAPE_N, LHS_IN_REG=False, num_warps=4):
a_layout = gl.NVMMASharedLayout.get_default_for(A.shape, gl.float16)
b_layout = gl.NVMMASharedLayout.get_default_for(B.shape, gl.float16)
cd_layout = gl.NVMMASharedLayout.get_default_for(C.shape, gl.float32)
a_desc = TensorDescriptor.from_tensor(A, A.shape, a_layout)
b_desc = TensorDescriptor.from_tensor(B, B.shape, b_layout)
c_desc = TensorDescriptor.from_tensor(C, C.shape, cd_layout)
d_desc = TensorDescriptor.from_tensor(D, D.shape, cd_layout)
small_mma_kernel[(1, )](
a_desc, b_desc, c_desc, d_desc, #
LHS_IN_REG, INSTR_SHAPE_N, num_warps=num_warps)
@pytest.mark.parametrize("M, N, K", [(64, 32, 32), (64, 256, 128)])
@pytest.mark.parametrize("LHS_IN_REG", [False, True])
@pytest.mark.parametrize("INSTR_SHAPE_N", [16, 64])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
def test_small_mma(M, N, K, LHS_IN_REG, INSTR_SHAPE_N, num_warps):
maxN = max(N // triton.cdiv(num_warps, triton.cdiv(M, 16)), 8)
if INSTR_SHAPE_N > maxN:
pytest.skip(f"INSTR_SHAPE_N={INSTR_SHAPE_N} is too large for M={M}, N={N}, num_warps={num_warps}")
A = torch.randn(M, K, device="cuda", dtype=torch.float16)
B = torch.randn(K, N, device="cuda", dtype=torch.float16)
C = torch.randn(M, N, device="cuda", dtype=torch.float32)
D = torch.empty_like(C)
small_mma(A, B, C, D, INSTR_SHAPE_N, LHS_IN_REG, num_warps)
torch.testing.assert_close(A @ B + C, D, atol=1e-3, rtol=1e-1)
Let’s study the performance impact of our knobs on WGMMA.
if __name__ == "__main__":
print("Benchmarking WGMMA")
print("==================")
M, N, K = 64, 128, 128
num_warps = 4
A = torch.randn(M, K, device="cuda", dtype=torch.float16)
B = torch.randn(K, N, device="cuda", dtype=torch.float16)
C = torch.randn(M, N, device="cuda", dtype=torch.float32)
D = torch.empty_like(C)
print("LHS_IN_REG INSTR_SHAPE_N time (us)")
for LHS_IN_REG, INSTR_SHAPE_N in itertools.product([False, True], [16, 32, 64, 128]):
fn = lambda: small_mma(A, B, C, D, INSTR_SHAPE_N, LHS_IN_REG, num_warps)
ms = triton.testing.do_bench(fn)
print(f"{LHS_IN_REG!s:>10} {INSTR_SHAPE_N:>13} {ms*1000:>9.2f}")
print()
LHS_IN_REG INSTR_SHAPE_N time (us)
False 16 9.47
False 32 8.48
False 64 8.32
False 128 8.32
True 16 9.32
True 32 8.60
True 64 8.37
True 128 8.36
Picking the largest n results in the best performance, because each
wgmma.mma_async instruction processes more data. In our case, placing the LHS
in registers is slower because we had to load the data out of shared memory
first. However, if the data were already in registers, it would be faster to
use it there instead of staging it through shared memory.
Just like warpgroup_mma is composed of multiple wgmma.mma_async
instructions tiled to cover our block size, we can also tile warpgroup_mma
to cover a much larger matmul. We can tile along K within each kernel and span
(M, N) with multiple programs. This leads to the classic blocked matmul
implementation. Let’s implement a basic version to demonstrate WGMMA.
# This decorator allows us to invoke the function from a Gluon constexpr.
@gluon.constexpr_function
def get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps):
warps_per_cta = [4, 1]
m = 16
# Tile the atom until we have enough warps.
while warps_per_cta[0] * warps_per_cta[1] != num_warps:
# Tile along M only if it would not cause broadcasting.
if BLOCK_M > m * warps_per_cta[0]:
warps_per_cta[0] *= 2
else:
warps_per_cta[1] *= 2
return warps_per_cta
@gluon.constexpr_function
def get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps):
m = 16
mReps = triton.cdiv(BLOCK_M, m)
nReps = triton.cdiv(num_warps, mReps)
maxN = max(BLOCK_N // nReps, 8)
n = 256
while n > maxN or BLOCK_N % n != 0:
n -= 8
assert n >= 8, "expected to find a valid n"
return n
@gluon.constexpr_function
def pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps):
m = 16
k = 256 // dtype.primitive_bitwidth
n = get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps)
warps_per_cta = get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps)
return gl.NVMMADistributedLayout(
version=[3, 0],
warps_per_cta=warps_per_cta,
instr_shape=[m, n, k],
)
@gluon.jit
def blocked_matmul_kernel(a_desc, b_desc, c_desc, #
TRANSPOSE_B: gl.constexpr, num_warps: gl.constexpr):
BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
dtype: gl.constexpr = a_desc.dtype
K = a_desc.shape[1]
a_smem = gl.allocate_shared_memory(dtype, a_desc.block_type.shape, a_desc.layout)
b_smem = gl.allocate_shared_memory(dtype, b_desc.block_type.shape, b_desc.layout)
# The block of C this program is processing is (pid_m, pid_n).
pid_m = gl.program_id(axis=0)
pid_n = gl.program_id(axis=1)
off_m = pid_m * BLOCK_M
off_n = pid_n * BLOCK_N
# Determine the WGMMA layout.
mma_layout: gl.constexpr = pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps)
acc = gl.zeros((BLOCK_M, BLOCK_N), dtype=gl.float32, layout=mma_layout)
bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
mbarrier.init(bar, count=1)
phase = 0
for k in range(0, K, BLOCK_K):
# Load tiles of A and B.
mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes)
tma.async_load(a_desc, [off_m, k], bar, a_smem)
if TRANSPOSE_B:
tma.async_load(b_desc, [off_n, k], bar, b_smem)
else:
tma.async_load(b_desc, [k, off_n], bar, b_smem)
mbarrier.wait(bar, phase=phase)
phase ^= 1 # toggle the parity phase between 0 and 1
# We can transpose B by creating a transposed view over tile of B in
# shared memory. This forwards the transposition to WGMMA, which handles
# it for us.
if TRANSPOSE_B:
b = b_smem.permute((1, 0))
else:
b = b_smem
acc = warpgroup_mma(a_smem, b, acc, is_async=True)
acc = warpgroup_mma_wait(num_outstanding=0, deps=(acc, ))
mbarrier.invalidate(bar)
# Downcast accumulator and store tile of C.
c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
c_smem.store(acc.to(dtype))
fence_async_shared()
tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)
tma.store_wait(pendings=0)
def blocked_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, TRANSPOSE_B, num_warps):
M, N = C.shape
a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16)
a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout)
B_BLOCK_SHAPE = [BLOCK_N, BLOCK_K] if TRANSPOSE_B else [BLOCK_K, BLOCK_N]
b_layout = gl.NVMMASharedLayout.get_default_for(B_BLOCK_SHAPE, gl.float16)
b_desc = TensorDescriptor.from_tensor(B, B_BLOCK_SHAPE, b_layout)
c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16)
c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout)
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
blocked_matmul_kernel[grid](a_desc, b_desc, c_desc, TRANSPOSE_B, num_warps=num_warps)
@pytest.mark.parametrize("M, N, K", [(208, 416, 304), (2000, 1000, 2000)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(64, 64, 64), (128, 128, 128)])
@pytest.mark.parametrize("TRANSPOSE_B", [False, True])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
def test_blocked_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, TRANSPOSE_B, num_warps):
A = torch.randn(M, K, device="cuda", dtype=torch.float16)
B = torch.randn((N, K) if TRANSPOSE_B else (K, N), device="cuda", dtype=torch.float16)
C = torch.empty(M, N, device="cuda", dtype=torch.float16)
blocked_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, TRANSPOSE_B, num_warps)
C_ref = A @ (B.T if TRANSPOSE_B else B)
torch.testing.assert_close(C_ref, C, rtol=1e-3, atol=1e-1)
We can benchmark this kernel as a baseline, but we need to pick the best block sizes. Rather than autotuning over all possibilities, we can apply some principles to narrow down the search space.
We should try to pick the largest n for the WGMMA layout. Based on the
formula for maxN, this requires BLOCK_N >= 256. Because our kernel does not
overlap the TMA loads with WGMMA, we will want more than one program resident
on each SM so that when one program stalls, the SM can switch to another. This
is known as “occupancy”. In detail, each SM has limited resources, and the
resource usage of a kernel determines its max occupancy. The SM schedules work
by warp using its warp schedulers, which can swap executing warps efficiently,
almost like hyperthreading.
Based on register and smem constraints, we can filter configs for the desired occupancy. Keep in mind that these are rules of thumb. It’s hard to know for sure if these lead to the best block sizes.
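Plugging in occupancy 2 with 4 warps gives the budgets the filter works against (the helper name is ours; the 228 KB smem and 64K-register figures are the Hopper per-SM limits, and the 1 KB and 16-regs-per-thread overheads are the same rough assumptions used below):

```python
def budgets(occupancy, num_warps):
    # Split the per-SM limits evenly across resident programs,
    # minus rough overheads for mbarriers and baseline register usage.
    smem_bytes = 228 * 1024 // occupancy - 1024
    regs = 64 * 1024 // occupancy - 16 * num_warps * 32
    return smem_bytes, regs

smem, regs = budgets(occupancy=2, num_warps=4)
print(smem // 1024, regs)  # 113 KB of smem and ~30K registers per program
```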
def find_configs(occupancy, dtype, num_buffers=1):
dtype_bytes = torch.tensor([], dtype=dtype).element_size()
# Assume ~1 KB of smem used by mbarriers, compiler-generated code, etc.
smem = 228 * 1024 // occupancy - 1024
configs = []
BLOCK_MNK = [32, 64, 128, 256]
for BLOCK_M, BLOCK_N, BLOCK_K, num_warps in itertools.product(BLOCK_MNK, BLOCK_MNK, BLOCK_MNK, [4, 8]):
# Assume ~16 regs per thread of baseline usage.
regs = 64 * 1024 // occupancy - 16 * num_warps * 32
a_smem = BLOCK_M * BLOCK_K * dtype_bytes
b_smem = BLOCK_N * BLOCK_K * dtype_bytes
acc_smem = BLOCK_M * BLOCK_N * dtype_bytes
# SMEM for A and B does not coexist with C.
if max((a_smem + b_smem) * num_buffers, acc_smem) > smem:
continue
# The accumulator is the only in-memory tensor in f32.
acc_regs = BLOCK_M * BLOCK_N
# Max regs per thread is 256. Being near this can also cause spills.
if acc_regs // num_warps // 32 >= 256:
continue
if acc_regs > regs:
continue
instr_shape_n = get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps)
configs.append((BLOCK_M, BLOCK_N, BLOCK_K, num_warps, instr_shape_n, occupancy))
def filter_configs(configs, instr_shape_n):
max_n_configs = [cfg for cfg in configs if cfg[4] == instr_shape_n]
# Filter for configs with the largest BLOCK_M * BLOCK_K.
max_block_mk = max(cfg[0] * cfg[2] for cfg in max_n_configs)
return [cfg for cfg in max_n_configs if cfg[0] * cfg[2] == max_block_mk]
top_instr_shape_n = sorted({cfg[4] for cfg in configs}, reverse=True)
result_configs = filter_configs(configs, top_instr_shape_n[0])
if len(top_instr_shape_n) > 1:
result_configs += filter_configs(configs, top_instr_shape_n[1])
return result_configs
if __name__ == "__main__":
print("Benchmarking selected configs")
print("=============================")
# Just in case, check occupancy 1 configs.
configs = find_configs(occupancy=1, dtype=torch.float16)
configs += find_configs(occupancy=2, dtype=torch.float16)
# Benchmark the configs over a large matmul. Keep in mind that the best
# hyperparameters can depend on the matmul shapes.
M, N, K = 8192, 8192, 16 * 1024
A = torch.randn(M, K, device="cuda", dtype=torch.float16)
B = torch.randn(K, N, device="cuda", dtype=torch.float16)
C = torch.empty(M, N, device="cuda", dtype=torch.float16)
print("BLOCK_M BLOCK_N BLOCK_K num_warps instr_shape_n occupancy time (ms) tflops/s")
for BLOCK_M, BLOCK_N, BLOCK_K, num_warps, instr_shape_n, occupancy in configs:
fn = lambda: blocked_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, False, num_warps)
ms = triton.testing.do_bench(fn)
flops = 2 * M * N * K
tflops_per_sec = flops * 1e-12 / (ms * 1e-3)
print(f"{BLOCK_M:>7} {BLOCK_N:>7} {BLOCK_K:>7} {num_warps:>9} {instr_shape_n:>13} "
f"{occupancy:>9} {ms:>9.2f} {tflops_per_sec:>8.2f}")
print()
BLOCK_M BLOCK_N BLOCK_K num_warps instr_shape_n occupancy time (ms) tflops/s
128 256 256 8 256 1 5.34 412.14
256 128 256 8 128 1 5.67 387.74
64 256 128 4 256 2 4.64 474.03
64 128 256 4 128 2 6.18 355.60
128 128 128 4 128 2 4.98 441.88
128 128 128 8 128 2 5.79 380.08
The hypothesis that occupancy 2 with BLOCK_N=256 would perform best has held
over our limited sample of hyperparameters. Autotuning over all
hyperparameters is left as an exercise for the reader.
474 TFLOPS is not a bad start. However, we aren’t using the fact that WGMMA is asynchronous, and we aren’t pipelining the TMA loads as shown in previous tutorials.
For now, let’s keep the loads synchronous and focus on pipelining the WGMMA. This requires us to double-buffer the operands, since we will be loading into the next set of buffers while WGMMA reads from the previous.
@gluon.jit
def blocked_matmul_pipelined_kernel(a_desc, b_desc, c_desc, num_warps: gl.constexpr):
BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
dtype: gl.constexpr = a_desc.dtype
K = a_desc.shape[1]
# Allocate 2 buffers for each A and B.
a_smem = gl.allocate_shared_memory(dtype, [2] + a_desc.block_type.shape, a_desc.layout)
b_smem = gl.allocate_shared_memory(dtype, [2] + b_desc.block_type.shape, b_desc.layout)
index = 0
pid_m = gl.program_id(axis=0)
pid_n = gl.program_id(axis=1)
off_m = pid_m * BLOCK_M
off_n = pid_n * BLOCK_N
mma_layout: gl.constexpr = pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps)
acc = warpgroup_mma_init(gl.zeros((BLOCK_M, BLOCK_N), dtype=gl.float32, layout=mma_layout))
bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
mbarrier.init(bar, count=1)
phase = 0
for k in range(0, K, BLOCK_K):
a = a_smem.index(index)
b = b_smem.index(index)
mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes)
tma.async_load(a_desc, [off_m, k], bar, a)
tma.async_load(b_desc, [k, off_n], bar, b)
mbarrier.wait(bar, phase=phase)
phase ^= 1
# Since `warpgroup_mma_wait` is a no-op when there are no WGMMAs in
# flight, we can overlap the WGMMA by waiting first, then issuing the
# async WGMMA.
acc = warpgroup_mma_wait(num_outstanding=0, deps=(acc, ))
acc = warpgroup_mma(a, b, acc, is_async=True)
# Move to the next buffer. The TMA load will start while the WGMMA is
# still running.
index ^= 1
# Wait for the last WGMMA to complete.
acc = warpgroup_mma_wait(num_outstanding=0, deps=(acc, ))
mbarrier.invalidate(bar)
c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
c_smem.store(acc.to(dtype))
fence_async_shared()
tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)
tma.store_wait(pendings=0)
def blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_warps):
M, N = C.shape
a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16)
b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gl.float16)
c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16)
a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout)
b_desc = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N], b_layout)
c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout)
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
blocked_matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, num_warps=num_warps)
@pytest.mark.parametrize("M, N, K", [(208, 416, 304), (2000, 1000, 2000)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(64, 64, 64), (128, 128, 128)])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
def test_blocked_matmul_pipelined(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_warps):
A = torch.randn(M, K, device="cuda", dtype=torch.float16)
B = torch.randn(K, N, device="cuda", dtype=torch.float16)
C = torch.empty(M, N, device="cuda", dtype=torch.float16)
blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_warps)
torch.testing.assert_close(A @ B, C, rtol=1e-3, atol=1e-1)
Let’s search for another set of configs, applying similar principles to prune the candidates. Our previous best block config would use 160 KB of smem with double buffering, too much for an occupancy of 2, yet at occupancy 1 it leaves performance on the table by not using the remaining 68 KB. It’s likely the best kernel reduces BLOCK_N in favour of keeping occupancy 2.
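The arithmetic behind those numbers, as a quick sanity check:

```python
# Previous best config: BLOCK_M, BLOCK_N, BLOCK_K = 64, 256, 128 in fp16.
BLOCK_M, BLOCK_N, BLOCK_K, dtype_bytes = 64, 256, 128, 2
a_tile = BLOCK_M * BLOCK_K * dtype_bytes  # 16 KB per A buffer
b_tile = BLOCK_K * BLOCK_N * dtype_bytes  # 64 KB per B buffer
double_buffered = 2 * (a_tile + b_tile)   # two buffers of each operand
print(double_buffered // 1024)            # 160 KB: exceeds 114 KB at occupancy 2
# At occupancy 1, that leaves 228 - 160 = 68 KB of the SM's smem unused.
print(228 - double_buffered // 1024)      # 68
```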
if __name__ == "__main__":
print("Benchmarking pipelined matmul")
print("=============================")
configs = find_configs(occupancy=1, dtype=torch.float16, num_buffers=2)
configs += find_configs(occupancy=2, dtype=torch.float16, num_buffers=2)
# Add our previous best config since it doesn't get selected.
configs.append([64, 256, 128, 4, 256, 2])
print("BLOCK_M BLOCK_N BLOCK_K num_warps instr_shape_n occupancy time (ms) tflops/s")
for BLOCK_M, BLOCK_N, BLOCK_K, num_warps, instr_shape_n, occupancy in configs:
fn = lambda: blocked_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_warps)
ms = triton.testing.do_bench(fn)
flops = 2 * M * N * K
tflops_per_sec = flops * 1e-12 / (ms * 1e-3)
print(f"{BLOCK_M:>7} {BLOCK_N:>7} {BLOCK_K:>7} {num_warps:>9} {instr_shape_n:>13} "
f"{occupancy:>9} {ms:>9.2f} {tflops_per_sec:>8.2f}")
print()
BLOCK_M BLOCK_N BLOCK_K num_warps instr_shape_n occupancy time (ms) tflops/s
128 256 128 8 256 1 5.16 426.06
256 128 128 8 128 1 5.70 385.85
64 256 64 4 256 2 5.27 417.50
64 128 128 4 128 2 5.71 384.98
128 128 64 4 128 2 4.44 495.31
128 128 64 8 128 2 4.92 446.81
64 256 128 4 256 2 6.05 363.36
Indeed, the best config ends up with instr_shape_n=128. Note that our previous best config is now over 100 TFLOPS slower! Pipelining the WGMMA delivers a modest ~5% speedup overall, but we had to re-tune the hyperparameters.
Pipelining both the async TMA loads and the WGMMA is left as an exercise to the reader.
Main takeaways:
WGMMA is a Hopper-specific instruction that performs block-level MMA.
WGMMA is asynchronous and can be overlapped with other operations.
WGMMA imposes several restrictions on operand layouts and instruction shapes.
The LHS operand can be in shared memory or registers; the RHS must be in shared memory.
WGMMA can handle transposed inputs, and we can create transposed views.
Pipelining the WGMMA leads to better performance by enabling overlap.
Hyperparameter tuning is critical for performance.