Multi-CTA
In Hopper, NVIDIA added a new thread group level to the hierarchy, the CGA. A CGA is a group of up to 16 CTAs that may collaborate with each other. In particular they can:
Load data from HBM collaboratively via TMA broadcasting
Exchange data by accessing each other’s shared memory. This is often called “using distributed shared memory”
Starting in Blackwell, pairs CTA can collaboratively compute the result of a matrix multiplication
Subsets of the CGA cluster can be selectively synchronized via
mbarriers
Of course, different CTAs may or may not be allocated to the same SM (in fact the documentation does not provide any guarantees about this) so operations like synchronisation or accessing each other’s shared memory are much more costly than accessing shared memory or synchronising the threads within a single CTA. As such, when using CGAs, the name of the game is to maximise the collaboration while not introducing unnecessary synchronisation points.
Multi-CTA Layouts
Layouts can be sharded across CTAs in a natural way. For example, we can have a blocked layout on a program with 4 warps and 2 CTAs of the form:
gl.BlockedLayout([1, 1], [1, 32], [1, 4], [1, 0], cga_layout=[[1, 0]])
The cga_layout representation [[1, 0]] denotes a linear layout. In this case, it
denotes that the two CTAs are sharding the tensor along the 0th dimension into two
contiguous subtensors.
Similarly, if we had 8 CTAs and we wanted to shard a shared memory descriptor across the 0th dimension using the first 4 CTAs and then across the 1st dimension using the last 4 CTAs, the layout could look like
gl.NVMMASharedLayout.get_default_for([M, N], gl.float16, cga_layout=[[1, 0], [2, 0], [0, 1]])
The cga_layout will always have log2(numCTAs) bases, and it will always denote sharding
the full tensor into contiguous chunks. For more sharding patterns where the CTAs may not
shard the tensor into contiguous subtensors, like
| CTA0 warp0 | | CTA1 warp0 | | CTA0 warp1 | | CTA1 warp1 |
one may use the layouts LinearEncoding for data in registers and SharedLinearEncoding
for data in shared memory. In these cases, rather than having an attribute called CGALayout
the CGA layout is encoded as part of the LinearLayout under the input dimension named block.
The example above would then look like:
gl.LinearEncoding(warps=[[2]], block=[[1]])
as we shard first along the CTAs and then along the warps.
import importlib
import pytest
import torch
import triton
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import (
TensorMemoryLayout,
allocate_tensor_memory,
clc,
tcgen05_commit,
tcgen05_mma,
tcgen05_mma_barrier_count,
tensor_memory_descriptor,
)
from triton.experimental.gluon.language.nvidia.hopper import mbarrier, tma
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
# Re-use baseline tutorials for comparisons.
t8 = importlib.import_module("08-warp-specialization")
def is_hopper_or_newer():
if not torch.cuda.is_available():
return False
target = triton.runtime.driver.active.get_current_target()
return target.backend == "cuda" and torch.cuda.get_device_capability()[0] >= 9
def is_blackwell():
if not torch.cuda.is_available():
return False
target = triton.runtime.driver.active.get_current_target()
return target.backend == "cuda" and torch.cuda.get_device_capability()[0] == 10
if __name__ == "__main__" and not is_blackwell():
raise RuntimeError("This tutorial requires a Blackwell NVIDIA GPU")
def tflops(ms, M, N, K):
return 2 * M * N * K * 1e-12 / (ms * 1e-3)
def gbps(ms, num_bytes):
return num_bytes * 1e-9 / (ms * 1e-3)
def pick_multicta_softmax_config(n_cols):
warp_thresholds = [(3072, 1), (6144, 2)]
cluster_thresholds = [(16 * 1024, 1), (32 * 1024, 2), (64 * 1024, 4), (128 * 1024, 8)]
num_warps = next((v for limit, v in warp_thresholds if n_cols <= limit), 4)
cluster_n = next((v for limit, v in cluster_thresholds if n_cols <= limit), 16)
return {
"num_warps": num_warps,
"num_ctas": cluster_n,
}
A multi-CTA kernel is launched with num_ctas > 1 where num_ctas is a power of two
and num_ctas <= 16.
Layout-driven operations such as gl.convert_layout, gl.reduce and gl.sum use
clusters automatically when the source and destination layouts shard the CTA dimension
differently.
The kernel below shards one row across multiple CTAs and uses the automatic
cross-CTA reductions in gl.max and gl.sum to compute a numerically stable
row-wise softmax.
Without CGAs, we would need to switch to an iterative reduction or a multi-kernel approach once the row becomes too wide for a single CTA.
@gluon.jit
def multicta_softmax_kernel(
x_ptr,
out_ptr,
x_row_stride,
out_row_stride,
BLOCK_N: gl.constexpr,
):
pid = gl.program_id(0)
cga_layout: gl.constexpr = ((1, ), (2, ), (4, ), (8, ), (16, ))[:gl.num_ctas().bit_length() - 1]
layout: gl.constexpr = gl.BlockedLayout([4], [32], [gl.num_warps()], [0], cga_layout=cga_layout)
offs_n = gl.arange(0, BLOCK_N, layout)
mask = offs_n < BLOCK_N
row_start = pid * x_row_stride
out_row_start = pid * out_row_stride
x = gl.load(x_ptr + row_start + offs_n, mask=mask, other=-float("inf"))
row_max = gl.max(x, axis=0)
y = gl.exp(x - row_max)
row_sum = gl.sum(y, axis=0)
z = y * (1.0 / row_sum)
gl.store(out_ptr + out_row_start + offs_n, z, mask=mask)
def multicta_softmax_f32(x, out=None):
M, N = x.shape
cfg = pick_multicta_softmax_config(N)
if out is None:
out = torch.empty_like(x)
multicta_softmax_kernel[(M, )](
x,
out,
x.stride(0),
out.stride(0),
BLOCK_N=N,
num_warps=cfg["num_warps"],
num_ctas=cfg["num_ctas"],
)
return out
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
@pytest.mark.parametrize("M, N", [(64, 64), (64, 256), (16, 2**16)])
def test_multicta_softmax_f32(M, N):
x = torch.randn((M, N), device="cuda", dtype=torch.float32)
out = multicta_softmax_f32(x)
ref = torch.softmax(x, dim=1)
torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5)
def benchmark_multicta_softmax_f32():
if not is_hopper_or_newer():
raise RuntimeError("softmax benchmark requires Hopper or newer")
SOFTMAX_BENCH_SHAPES = [
(2**15, 2**8),
(2**15, 2**9),
(2**15, 2**10),
(2**15, 2**11),
(2**15, 2**12),
(2**15, 2**13),
(2**15, 2**14),
(2**15, 2**15),
(2**15, 2**16),
(2**14, 2**17),
(2**13, 2**18),
]
print("Benchmarking multicta_softmax")
print("============================")
print(" shape CTAs warps time (ms) bandwidth (GB/s)")
for M, N in SOFTMAX_BENCH_SHAPES:
cfg = pick_multicta_softmax_config(N)
x = torch.empty((M, N), device="cuda", dtype=torch.float32).uniform_(-1, 1)
out = torch.empty_like(x)
ms = triton.testing.do_bench_cudagraph(lambda: multicta_softmax_f32(x, out))
num_bytes = 2 * x.numel() * x.element_size()
print(f"{M:>6} x {N:<6} {cfg['num_ctas']:>4} {cfg['num_warps']:>5} {ms:>9.3f} {gbps(ms, num_bytes):>16.2f}")
benchmark_multicta_softmax_f32()
Softmax benchmark results
Benchmarking multicta_softmax
============================
shape CTAs warps time (ms) bandwidth (GB/s)
32768 x 256 1 1 0.018 3661.46
32768 x 512 1 1 0.020 6746.45
32768 x 1024 1 1 0.040 6740.50
32768 x 2048 1 1 0.078 6920.01
32768 x 4096 1 2 0.152 7065.25
32768 x 8192 1 4 0.301 7136.76
32768 x 16384 1 4 0.600 7157.74
32768 x 32768 2 4 1.312 6545.11
32768 x 65536 4 4 2.836 6057.26
16384 x 131072 8 4 3.142 5468.66
8192 x 262144 16 4 3.627 4736.15
We see that here using multiCTA we are able to get very good performance across the board.
Multi-CTA synchronization
Since CTAs may be on different SMs, sychronization is much slower than within a CTA. As such, gluon provides a rather conservative automatic synchronization guarantee, and the user is responsible for the rest of the synchronization.
Gluon will place synchronisaton primitives between operations like gl.convert_layout, gl.reduce and gl.sum when the source and destination layouts shard the CTA dimension differently. All the other operations like TMA, WGMMA, TCGen5MMA, etc. should be synchronized by the kernel writer via mbarriers, same as it’s done for single-CTA kernels.
The semantics of the cga_layout for a multi-CTA mbarrier are slightly different: As discussed in 02-layouts.py, a linear layout represents a map from F_2^n to F_2^m. In this case, the cga_layout is a map from the numCTAs (which is a power of two) to the number of barriers it represents. For example, we could have a mbarrier layout where each CTA has its own barrier.
num_ctas: gl.constexpr = 4
bar = gl.allocate_shared_memory(gl.int64, [num_ctas], MBarrierLayout(cga_layout=[[1], [2], [4]]))
So, we define the cga_layout matrix by its columns (in binary), which represents the 3x3 identity matrix. Since this pattern is so common, gluon provides a helper function to create it:
bar = mbarrier.allocate_mbarrier()
Now, barrier layouts also allow for cross-CTA synchronization. For example, we an define a 2-CTA mbarrier for an 8CTA kernel as:
bar = gl.allocate_shared_memory(gl.int64, [4], MBarrierLayout(cga_layout=[[0], [1], [2]]))
Note that now the non-zero bases are just [1] and [2], so there are just 22 = 4 barriers. Since it’s an 8 CTA kernel, there are 23 = 8 bases though. The layout now has broadcasting on the 0th column. What that means is that any CTA that just differs on the 0th bit will share a barrier. For example, CTA0 and CTA1 will share a barrier, CTA2 and CTA3 will and so on. The lead CTA is the smallest CTA id in the group. For this layout, the even CTA IDs are the lead CTAs.
In general, an mbarrier cta_layout is a sequence [[2**i] for i in range(k)] for k <= log2(num_ctas)
with log2(num_ctas) - k zeros interleaved.
All the operations that act on barriers generalize naturally to multi-CTA barriers. More explicitly:
mbarrier.initmultiplies the count argument by the number of CTAs in the group and it’s only initialized on the lead CTAmbarrier.expectmultiplies the size_per_cta argument by the number of CTAs in the group and it’s only expected on the lead CTA since an expect counts as one arrival, all the non-lead CTAs will also emit one arrival to the lead CTA.mbarrier.arriveevery CTA in a group arrives on the lead CTAmbarrier.waitjust the lead CTA waits for the barrier
Final note on synchronization. cluster.arrive / cluster.wait (i.e., CGA barriers, the cluster equivalent of bar.sync for CTAs) must be executed by all threads in the kernel. As a result, they cannot be used inside a warp_specialize block.
Moreover, operations such as convert_layout, reduce, sum, max, etc., emit CGA barriers when they cross CTAs. Therefore, these operations are also not allowed inside a warp_specialize block whenever they may span multiple CTAs.
2CTA TCGen5MMA
In 2CTA mode, the tcgen05_mma instruction uses data from every other CTA in a pair (i.e. CTA0 and CTA1, CTA2 and CTA3, etc.) to compute the result. In mathematical terms, it computes the outer product of the two operands, where the LHS holds the input sharded along the M dimension and the RHS holds the input sharded along the N dimension. In terms of cga_layouts, the LHS has its first basis equal to (1, 0) and the RHS has its first basis equal to (0, 1). The accumulator is also shared as (1, 0)
Same as for single-CTA, the blockM shape of TensorMemoryLayout is the shape
of the instruction, which can be either 64 or 128.
In Gluon, 2CTA mode is selected on the accumulator layout via
TensorMemoryLayout(..., two_ctas=True).
If one tcgen05_mma instruction uses 2CTA mode, the kernel is declared as using 2CTA mode. In this case, all the other tcgen05_mma instructions in the kernel must use 2CTA mode.
The mma_bar itself does not need two_ctas=True. It is a regular
multi-CTA barrier, and tcgen05_mma will multicast its completion signal to
the two CTAs in the pair. The TMA hand-off barrier does need
two_ctas=True, because only the lead CTA waits before issuing the MMA.
Once one tcgen05_mma in a kernel uses 2CTA mode, all of the tcgen05_mma
instructions in that kernel must use 2CTA mode.
The tcgen05_mma instruction is issued from the lead CTA in each pair. As such,
when used in conjunction with TMA, the TMA barrier needs to be two_ctas=True.
What this does is that it creates a barrier with cga_layout[0] = [0], which means
that CTA0 will wait for both its data and the data from CTA1 to be loaded before
issuing the MMA.
The kernel two_cta_tcgen05_kernel shows the 2CTA TCGen5MMA pattern on a single tile.
It’s worth noting that the pattern changes a bit once the TMA has to wait on the tcgen05_mma We will cover this in the next section.
@gluon.jit
def two_cta_tcgen05_kernel(a_desc, b_desc, c_desc):
gl.static_assert(gl.num_ctas() == 2)
cluster_m: gl.constexpr = a_desc.block_shape[0]
tile_n: gl.constexpr = b_desc.block_shape[1]
cta_m: gl.constexpr = cluster_m // 2
cga_layout: gl.constexpr = c_desc.layout.cga_layout
smem_a = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout)
smem_b = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout)
tma_bar = mbarrier.allocate_mbarrier(two_ctas=True)
mma_bar = mbarrier.allocate_mbarrier()
mbarrier.init(tma_bar, count=1)
mbarrier.init(mma_bar, count=1)
mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
tma.async_load(a_desc, [0, 0], tma_bar, smem_a)
tma.async_load(b_desc, [0, 0], tma_bar, smem_b)
mbarrier.wait(tma_bar, phase=0, deps=[smem_a, smem_b])
mbarrier.invalidate(tma_bar)
acc_layout: gl.constexpr = TensorMemoryLayout(
block=(cta_m, tile_n),
col_stride=1,
cga_layout=cga_layout,
two_ctas=True,
)
acc = allocate_tensor_memory(gl.float32, [cluster_m, tile_n], acc_layout)
tcgen05_mma(smem_a, smem_b, acc, use_acc=False, mbarriers=[mma_bar])
mbarrier.wait(mma_bar, phase=0, deps=[smem_a, smem_b])
mbarrier.invalidate(mma_bar)
c_smem = gl.allocate_shared_memory(c_desc.dtype, c_desc.block_shape, c_desc.layout)
c_smem.store(acc.load().to(c_desc.dtype))
tma.async_copy_shared_to_global(c_desc, [0, 0], c_smem)
def run_two_cta_tcgen05(a, b, c):
M, N, K = a.shape[0], b.shape[1], a.shape[1]
a_layout = gl.NVMMASharedLayout.get_default_for([M, K], gl.float16, cga_layout=[(1, 0)])
b_layout = gl.NVMMASharedLayout.get_default_for([K, N], gl.float16, cga_layout=[(0, 1)])
c_layout = gl.NVMMASharedLayout.get_default_for([M, N], gl.float16, cga_layout=[(1, 0)])
a_desc = TensorDescriptor.from_tensor(a, [M, K], a_layout)
b_desc = TensorDescriptor.from_tensor(b, [K, N], b_layout)
c_desc = TensorDescriptor.from_tensor(c, [M, N], c_layout)
two_cta_tcgen05_kernel[(1, )](a_desc, b_desc, c_desc, num_warps=4, num_ctas=2)
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_two_cta_tcgen05():
M, N, K = 256, 128, 64
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)
c = torch.empty((M, N), device="cuda", dtype=torch.float16)
run_two_cta_tcgen05(a, b, c)
torch.testing.assert_close(c, torch.matmul(a, b), atol=1e-1, rtol=1e-2)
There are a few things that change from the single-CTA case:
The TMA barrier is
two_ctas=True. This was covered in the previous section. - Thembarrier.expectis called with the total byte count per CTA, not the whole block.The tcgen05_mma TMEM layout is now
two_ctas=True
Note that there will be a few more changes once this is used in a for-loop and/or with TMA with multicast. More on this in the next section.
TMA with multicast
Since Hopper onwards, TMA has the ability to multicast data to multiple CTAs. This is useful in multi-CTA kernels in Hopper as wgmma does not have a 2CTA mode, or on Blackwell+ kernels when using more than 2 CTAs.
In this case, for a cga_layout for the accumulator, we may compute the layouts
for A and B as follows:
# Example cga_layout
cga_layout = [(1, 0), (2, 0), (0, 1)]
def get_cga_layout(layout, op_idx, two_ctas):
assert op_idx in (0, 1)
if not layout:
return layout
# Broadcast along K (the reduction dimension)
# We multiply by 2 for op_idx == 1, as we have added the (0, 1) basis.
def broadcast(b):
mul = 2 if two_ctas else 1
return (b[0], 0) if op_idx == 0 else (0, mul * b[1])
if not two_ctas:
return tuple(map(broadcast, layout))
# 2CTA performs an outer product so bases are [1, 0] and [0, 1]
assert layout[0] == (1, 0)
first = (1, 0) if op_idx == 0 else (0, 1)
return (first, *map(broadcast, layout[1:]))
cga_layout_a = get_cga_layout(cga_layout, 0, two_ctas=False)
cga_layout_b = get_cga_layout(cga_layout, 1, two_ctas=False)
In other words, the cga_layout of A and B is that of C zeroing out the inner dimension for each.
This means that some bases are zero for A and/or B, so different CTAs will load the same data. Multicast will allow these CTAs to hit the L2 cache efficiently.
The synchronization pattern is the same as for a regular TMA load:
initialize a barrier,
expectthe byte count,issue the TMA with
multicast=True,wait on the barrier.
The reason this works is that the TMA instruction broadcasts its arrival to every CTA in the multicast group atomically, so the wait side does not need a different API.
The TMA destination must use a broadcast cga_layout, so that both CTAs
receive the same shared-memory tile. The barrier stays a regular 1D TMA
barrier unless the kernel is in 2CTA mode.
The example below keeps things intentionally simple: it multicasts one tile into shared memory and then materializes that same tile back to global memory.
@gluon.jit
def tma_multicast_copy_kernel(in_desc, out_desc):
gl.static_assert(gl.num_ctas() == 2)
smem = gl.allocate_shared_memory(in_desc.dtype, in_desc.block_shape, in_desc.layout)
# This kernel is not in 2CTA mode, so the TMA barrier is per-CTA.
bar = mbarrier.allocate_mbarrier()
mbarrier.init(bar, count=1)
mbarrier.expect(bar, in_desc.nbytes_per_cta)
tma.async_load(in_desc, [0, 0], bar, smem, multicast=True)
mbarrier.wait(bar, phase=0, deps=[smem])
tma.async_copy_shared_to_global(out_desc, [0, 0], smem)
def run_tma_multicast_copy(inp, out):
layout = gl.NVMMASharedLayout.get_default_for(inp.shape, gl.float16, cga_layout=[[0, 0]])
in_desc = TensorDescriptor.from_tensor(inp, inp.shape, layout)
out_desc = TensorDescriptor.from_tensor(out, inp.shape, layout)
tma_multicast_copy_kernel[(1, )](in_desc, out_desc, num_warps=4, num_ctas=2)
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
def test_tma_multicast_copy():
M, N = 128, 128
inp = torch.randn((M, N), device="cuda", dtype=torch.float16)
out = torch.empty_like(inp)
run_tma_multicast_copy(inp, out)
torch.testing.assert_close(out, inp, atol=0, rtol=0)
TMA into MMA in a loop
Here we illustrate the fully generic approach to mixing TMA with (or without) multicast into a tcgen05_mma pipeline.
In this case, the tcgen05_mma instruction needs to wait for all the CTAs in its multicast group to complete before it can continue the next iteration, as otherwise the next iteration’s TMA loads will overwrite the data from the previous iteration before it has finished consuming it.
As such, in this case, we need to use tcgen05_mma_barrier_count to compute the number of CTAs
in a multicast group. Similarly we set the multicast=True flag on the tcgen05_mma instruction
to note that it will have to wait for the multicast group to complete before it can continue.
These functions are generic, so a pattern of this form would work also for non-multicast kernels or non-2CTA kernels.
@gluon.jit
def tma_tcgen05_kernel(a_desc, b_desc, out_desc, NUM_K_TILES: gl.constexpr, acc_tmem_layout: gl.constexpr):
block_m: gl.constexpr = a_desc.block_shape[0]
block_k: gl.constexpr = a_desc.block_shape[1]
block_n: gl.constexpr = b_desc.block_shape[1]
smem_a = gl.allocate_shared_memory(a_desc.dtype, a_desc.block_shape, a_desc.layout)
smem_b = gl.allocate_shared_memory(b_desc.dtype, b_desc.block_shape, b_desc.layout)
tma_bar = mbarrier.allocate_mbarrier(two_ctas=True)
mma_bar = mbarrier.allocate_mbarrier()
mbarrier.init(tma_bar, count=1)
mbarrier.init(mma_bar, count=tcgen05_mma_barrier_count([smem_a, smem_b], multicast=True))
acc_tmem = allocate_tensor_memory(gl.float32, [block_m, block_n], acc_tmem_layout)
phase_tma = 0
phase_mma = 0
for k in range(NUM_K_TILES):
mbarrier.expect(tma_bar, a_desc.nbytes_per_cta + b_desc.nbytes_per_cta)
tma.async_load(a_desc, [0, k * block_k], tma_bar, smem_a, multicast=True)
tma.async_load(b_desc, [k * block_k, 0], tma_bar, smem_b, multicast=True)
mbarrier.wait(tma_bar, phase=phase_tma, deps=[smem_a, smem_b])
phase_tma ^= 1
tcgen05_mma(smem_a, smem_b, acc_tmem, use_acc=(k != 0), multicast=True, mbarriers=[mma_bar])
mbarrier.wait(mma_bar, phase=phase_mma, deps=[smem_a, smem_b])
phase_mma ^= 1
mbarrier.invalidate(tma_bar)
mbarrier.invalidate(mma_bar)
out_smem = gl.allocate_shared_memory(out_desc.dtype, out_desc.block_shape, out_desc.layout)
out_smem.store(acc_tmem.load().to(out_desc.dtype))
tma.async_copy_shared_to_global(out_desc, [0, 0], out_smem)
def tma_tcgen05_example(a, b):
BLOCK_M = 512
BLOCK_N = 128
BLOCK_K = 64
NUM_K_TILES = 2
cga_layout_a = ((1, 0), (2, 0))
cga_layout_b = ((0, 1), (0, 0))
cga_layout_c = ((1, 0), (2, 0))
M, K = a.shape
Kb, N = b.shape
if K != Kb:
raise ValueError(f"inner dimensions must match, got {K} and {Kb}")
if M != BLOCK_M or N != BLOCK_N or K != BLOCK_K * NUM_K_TILES:
raise ValueError(f"expected shapes {(BLOCK_M, BLOCK_K * NUM_K_TILES)} x "
f"{(BLOCK_K * NUM_K_TILES, BLOCK_N)}, got {tuple(a.shape)} x {tuple(b.shape)}")
out = torch.empty((M, N), device="cuda", dtype=torch.float16)
a_layout = gl.NVMMASharedLayout.get_default_for([M, BLOCK_K], gl.float16, cga_layout=cga_layout_a)
b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, N], gl.float16, cga_layout=cga_layout_b)
c_layout = gl.NVMMASharedLayout.get_default_for([M, N], gl.float16, cga_layout=cga_layout_c)
acc_tmem_layout = TensorMemoryLayout(block=(128, N), col_stride=1, cga_layout=cga_layout_c, two_ctas=True)
a_desc = TensorDescriptor.from_tensor(a, [M, BLOCK_K], a_layout)
b_desc = TensorDescriptor.from_tensor(b, [BLOCK_K, N], b_layout)
c_desc = TensorDescriptor.from_tensor(out, [M, N], c_layout)
tma_tcgen05_kernel[(1, )](
a_desc,
b_desc,
c_desc,
NUM_K_TILES,
acc_tmem_layout,
num_warps=4,
num_ctas=4,
)
return out
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_tma_tcgen05():
M = 512
N = 128
K = 128
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)
out = tma_tcgen05_example(a, b)
torch.testing.assert_close(out, torch.matmul(a, b), atol=1e-1, rtol=1e-2)
A Speed Of Light matmul kernel
Here we illustrate a fully generic approach to writing a matmul kernel that uses TMA (perhaps with multicast) into warp-specialized tcgen05_mma pipeline.
For this example, we generalise the CLC ideas presented in 12-cluster-launch-control.py to a warp-specialized kernel by adding a new partition that handles the CLC generation and broadcasts this to all CTAs
We also use an extra helper called _planar_snake to swizzle the program id’s to improve L2 locality.
Counter = t8.Counter
cublas = t8.cublas
@gluon.constexpr_function
def get_split_dim(cga_layout, dim):
return 1 << sum(b[dim] != 0 for b in cga_layout)
@gluon.jit
def _planar_snake(lin_idx, m_tiles, n_tiles, minor_dim: gl.constexpr, tile_width: gl.constexpr):
major_size = n_tiles if minor_dim == 0 else m_tiles
minor_size = m_tiles if minor_dim == 0 else n_tiles
full_minor_tiles = minor_size // tile_width
full_minor_size = full_minor_tiles * tile_width
full_elements = full_minor_tiles * tile_width * major_size
minor_tile_idx = lin_idx // (tile_width * major_size)
full_minor_within = lin_idx % tile_width
full_major_within = (lin_idx // tile_width) % major_size
full_minor = minor_tile_idx * tile_width + full_minor_within
full_major = gl.where((minor_tile_idx % 2) == 0, full_major_within, major_size - 1 - full_major_within)
partial_width = minor_size - full_minor_size
partial_width = gl.where(partial_width > 0, partial_width, 1)
partial_lin = lin_idx - full_elements
partial_minor_within = partial_lin % partial_width
partial_major_within = (partial_lin // partial_width) % major_size
partial_minor = minor_tile_idx * tile_width + partial_minor_within
partial_major = gl.where((minor_tile_idx % 2) == 0, partial_major_within, major_size - 1 - partial_major_within)
in_full_tile = lin_idx < full_elements
minor = gl.where(in_full_tile, full_minor, partial_minor)
major = gl.where(in_full_tile, full_major, partial_major)
if minor_dim == 0:
return minor, major
return major, minor
@gluon.aggregate
class ClcTileSchedulerConsumer:
has_work: gl.tensor
tile_id: gl.tensor
pid_m: gl.tensor
pid_n: gl.tensor
num_pid_m: gl.tensor
num_pid_n: gl.tensor
TILE_M: gl.constexpr
TILE_N: gl.constexpr
MINOR_DIM: gl.constexpr
GRID_TILE_WIDTH: gl.constexpr
clc_result_buffers: gl.shared_memory_descriptor
clc_barriers: gl.shared_memory_descriptor
clc_planar_pid_buffers: gl.shared_memory_descriptor
clc_planar_ready_bars: gl.shared_memory_descriptor
clc_consumed_bars: gl.shared_memory_descriptor
counter: Counter
consumed_counter: Counter
@gluon.jit
def initialize(M, N, TILE_M: gl.constexpr, TILE_N: gl.constexpr, MINOR_DIM: gl.constexpr,
GRID_TILE_WIDTH: gl.constexpr, clc_result_buffers, clc_barriers, clc_planar_pid_buffers,
clc_planar_ready_bars, clc_consumed_bars):
tile_id = gl.program_id(axis=0)
num_pid_m = gl.cdiv(M, TILE_M)
num_pid_n = gl.cdiv(N, TILE_N)
pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, MINOR_DIM, GRID_TILE_WIDTH)
return ClcTileSchedulerConsumer(
gl.to_tensor(True),
tile_id,
pid_m,
pid_n,
num_pid_m,
num_pid_n,
TILE_M,
TILE_N,
MINOR_DIM,
GRID_TILE_WIDTH,
clc_result_buffers,
clc_barriers,
clc_planar_pid_buffers,
clc_planar_ready_bars,
clc_consumed_bars,
Counter.create(0, clc_barriers.shape[0]),
Counter.create(0, clc_barriers.shape[0]),
)
@gluon.jit
def get_offsets(self):
return self.pid_m * self.TILE_M, self.pid_n * self.TILE_N
@gluon.jit
def step(self, iteration):
consumed_counter = self.consumed_counter
if iteration > 0:
mbarrier.arrive(self.clc_consumed_bars.index(consumed_counter.index))
consumed_counter = consumed_counter.next()
counter = self.counter
barrier = self.clc_barriers.index(counter.index)
result = self.clc_result_buffers.index(counter.index)
mbarrier.wait(barrier, counter.phase)
clc_res = clc.load_result(result)
mbarrier.wait(self.clc_planar_ready_bars.index(counter.index), counter.phase)
planar_slot = self.clc_planar_pid_buffers.index(counter.index)
planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
[[0]] * (gl.num_ctas().bit_length() - 1))
packed_pid = planar_slot.load(planar_layout).reshape([])
pid_m = ((packed_pid >> 32) & 0xFFFFFFFF).to(gl.int32)
pid_n = (packed_pid & 0xFFFFFFFF).to(gl.int32)
has_work = clc_res.is_canceled()
tile_id = self.tile_id
if has_work:
tile_id = clc_res.program_id(0)
return ClcTileSchedulerConsumer(
has_work,
tile_id,
pid_m,
pid_n,
self.num_pid_m,
self.num_pid_n,
self.TILE_M,
self.TILE_N,
self.MINOR_DIM,
self.GRID_TILE_WIDTH,
self.clc_result_buffers,
self.clc_barriers,
self.clc_planar_pid_buffers,
self.clc_planar_ready_bars,
self.clc_consumed_bars,
counter.next(),
consumed_counter,
)
@gluon.aggregate
class MatmulPartitionArgs:
a_desc: tma.tensor_descriptor
b_desc: tma.tensor_descriptor
c_desc: tma.tensor_descriptor
a_bufs: gl.shared_memory_descriptor
b_bufs: gl.shared_memory_descriptor
load_empty_bars: gl.shared_memory_descriptor
load_ready_bars: gl.shared_memory_descriptor
acc_bufs: tensor_memory_descriptor
acc_empty_bars: gl.shared_memory_descriptor
acc_ready_bars: gl.shared_memory_descriptor
clc_result_buffers: gl.shared_memory_descriptor
clc_barriers: gl.shared_memory_descriptor
clc_planar_pid_buffers: gl.shared_memory_descriptor
clc_planar_ready_bars: gl.shared_memory_descriptor
clc_consumed_bars: gl.shared_memory_descriptor
MINOR_DIM: gl.constexpr
GRID_TILE_WIDTH: gl.constexpr
SUBTILE_STAGES: gl.constexpr
@gluon.jit
def get_clc_consumer(self):
return ClcTileSchedulerConsumer.initialize(
self.c_desc.shape[0],
self.c_desc.shape[1],
self.a_desc.block_shape[0],
self.b_desc.block_shape[1],
self.MINOR_DIM,
self.GRID_TILE_WIDTH,
self.clc_result_buffers,
self.clc_barriers,
self.clc_planar_pid_buffers,
self.clc_planar_ready_bars,
self.clc_consumed_bars,
)
@gluon.jit
def matmul_clc_partition(p):
tile_m: gl.constexpr = p.a_desc.block_shape[0]
tile_n: gl.constexpr = p.b_desc.block_shape[1]
has_work = gl.to_tensor(True)
num_pid_m = gl.cdiv(p.c_desc.shape[0], tile_m)
num_pid_n = gl.cdiv(p.c_desc.shape[1], tile_n)
state = Counter.create(0, p.clc_barriers.shape[0])
consumed_state = Counter.create(1, p.clc_barriers.shape[0])
acc_stages: gl.constexpr = p.clc_barriers.shape[0]
i = 0
while has_work:
mbarrier.wait(p.clc_consumed_bars.index(consumed_state.index), consumed_state.phase, pred=(i >= acc_stages))
barrier = p.clc_barriers.index(state.index)
result = p.clc_result_buffers.index(state.index)
mbarrier.expect(barrier, 16)
clc.try_cancel(result, barrier)
mbarrier.wait(barrier, state.phase)
clc_res = clc.load_result(result)
has_work = clc_res.is_canceled()
pid_m = gl.to_tensor(0)
pid_n = gl.to_tensor(0)
if has_work:
tile_id = clc_res.program_id(0)
pid_m, pid_n = _planar_snake(tile_id, num_pid_m, num_pid_n, p.MINOR_DIM, p.GRID_TILE_WIDTH)
packed_pid = (pid_m.to(gl.int64) << 32) | (pid_n.to(gl.int64) & 0xFFFFFFFF)
planar_slot = p.clc_planar_pid_buffers.index(state.index)
planar_layout: gl.constexpr = gl.BlockedLayout([1], [32], [gl.num_warps()], [0],
[[0]] * (gl.num_ctas().bit_length() - 1))
planar_slot.store(gl.full([1], packed_pid, gl.int64, layout=planar_layout))
mbarrier.arrive(p.clc_planar_ready_bars.index(state.index))
state = state.next()
consumed_state = consumed_state.next()
i += 1
@gluon.jit
def matmul_load_partition(p):
block_k: gl.constexpr = p.a_desc.block_shape[1]
K = p.a_desc.shape[1]
concurrent_loads: gl.constexpr = p.load_ready_bars.shape[0]
state = Counter.create(1, concurrent_loads)
scheduler = p.get_clc_consumer()
i = 0
while scheduler.has_work:
off_m, off_n = scheduler.get_offsets()
for k in range(0, K, block_k):
pred = (i > 0) or (k >= block_k * concurrent_loads)
mbarrier.wait(p.load_empty_bars.index(state.index), state.phase, pred=pred)
bar = p.load_ready_bars.index(state.index)
mbarrier.expect(bar, p.a_desc.nbytes_per_cta + p.b_desc.nbytes_per_cta)
tma.async_load(p.a_desc, [off_m, k], bar, p.a_bufs.index(state.index), multicast=True)
tma.async_load(p.b_desc, [k, off_n], bar, p.b_bufs.index(state.index), multicast=True)
state = state.next()
scheduler = scheduler.step(i)
i += 1
@gluon.jit
def matmul_mma_partition(p):
block_k: gl.constexpr = p.a_desc.block_shape[1]
K = p.a_desc.shape[1]
acc_stages: gl.constexpr = p.acc_empty_bars.shape[0]
load_state = Counter.create(0, p.load_empty_bars.shape[0])
acc_state = Counter.create(1, acc_stages)
scheduler = p.get_clc_consumer()
i = 0
while scheduler.has_work:
acc_buf = p.acc_bufs.index(acc_state.index)
mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase, pred=(i >= acc_stages))
use_acc = False
for k in range(0, K, block_k):
mbarrier.wait(p.load_ready_bars.index(load_state.index), load_state.phase)
tcgen05_mma(
p.a_bufs.index(load_state.index),
p.b_bufs.index(load_state.index),
acc_buf,
use_acc=use_acc,
multicast=True,
mbarriers=[p.load_empty_bars.index(load_state.index)],
)
load_state = load_state.next()
use_acc = True
tcgen05_commit(p.acc_ready_bars.index(acc_state.index), descs=[p.a_bufs.index(0), p.b_bufs.index(0)])
acc_state = acc_state.next()
scheduler = scheduler.step(i)
i += 1
@gluon.jit
def matmul_epilogue_partition(p):
tile_m: gl.constexpr = p.a_desc.block_shape[0]
tile_n: gl.constexpr = p.b_desc.block_shape[1]
split_tile_n: gl.constexpr = p.c_desc.block_shape[1]
# Separate knobs: SUBTILE_STAGES controls shared-memory usage,
# and SUBTILE_FACTOR is the maximum number of subtiles into which we can split the tile.
subtile_factor: gl.constexpr = tile_n // split_tile_n
subtile_stages: gl.constexpr = p.SUBTILE_STAGES
acc_stages: gl.constexpr = p.acc_empty_bars.shape[0]
dtype: gl.constexpr = p.c_desc.dtype
acc_state = Counter.create(0, acc_stages)
acc_smems = gl.allocate_shared_memory(dtype, [subtile_stages, tile_m, split_tile_n], p.c_desc.layout)
sub_acc_state = Counter.create(0, subtile_stages)
scheduler = p.get_clc_consumer()
i = 0
while scheduler.has_work:
off_m, off_n = scheduler.get_offsets()
mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
acc_buf = p.acc_bufs.index(acc_state.index)
for s in gl.static_range(subtile_factor):
acc_sub = acc_buf.slice(split_tile_n * s, split_tile_n)
acc_smem = acc_smems.index(sub_acc_state.index)
acc = acc_sub.load().to(dtype)
tma.store_wait(pendings=subtile_stages - 1)
acc_smem.store(acc)
tma.async_copy_shared_to_global(p.c_desc, [off_m, off_n + split_tile_n * s], acc_smem)
sub_acc_state = sub_acc_state.next()
mbarrier.arrive(p.acc_empty_bars.index(acc_state.index))
acc_state = acc_state.next()
scheduler = scheduler.step(i)
i += 1
@gluon.jit
def matmul_multicta_kernel(
a_desc,
b_desc,
c_desc,
M,
N,
K,
BLOCK_SIZE_M: gl.constexpr,
BLOCK_SIZE_N: gl.constexpr,
BLOCK_SIZE_K: gl.constexpr,
GRID_MINOR_DIM: gl.constexpr,
GRID_TILE_WIDTH: gl.constexpr,
STAGES: gl.constexpr,
ACC_STAGES: gl.constexpr,
CGA_LAYOUT: gl.constexpr,
EPILOGUE_SIZE_N: gl.constexpr,
SUBTILE_STAGES: gl.constexpr,
):
block_m: gl.constexpr = a_desc.block_shape[0]
block_n: gl.constexpr = b_desc.block_shape[1]
two_ctas: gl.constexpr = gl.num_ctas() > 1
n_partitions: gl.constexpr = 4
dtype: gl.constexpr = a_desc.dtype
a_bufs = gl.allocate_shared_memory(dtype, [STAGES] + a_desc.block_shape, a_desc.layout)
b_bufs = gl.allocate_shared_memory(dtype, [STAGES] + b_desc.block_shape, b_desc.layout)
mma_barrier_count: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], multicast=True)
load_empty_bars = mbarrier.allocate_mbarrier(batch=STAGES)
load_ready_bars = mbarrier.allocate_mbarrier(batch=STAGES, two_ctas=two_ctas)
for i in gl.static_range(STAGES):
mbarrier.init(load_empty_bars.index(i), count=mma_barrier_count)
mbarrier.init(load_ready_bars.index(i), count=1)
tmem_layout: gl.constexpr = TensorMemoryLayout(
[BLOCK_SIZE_M, block_n // get_split_dim(CGA_LAYOUT, 1)],
col_stride=1,
cga_layout=CGA_LAYOUT,
two_ctas=two_ctas,
)
acc_bufs = allocate_tensor_memory(gl.float32, [ACC_STAGES, block_m, block_n], tmem_layout)
acc_empty_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=two_ctas)
acc_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
for i in gl.static_range(ACC_STAGES):
mbarrier.init(acc_empty_bars.index(i), count=1)
mbarrier.init(acc_ready_bars.index(i), count=mma_barrier_count)
clc_barriers = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
clc_planar_ready_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES)
clc_consumed_bars = mbarrier.allocate_mbarrier(batch=ACC_STAGES, two_ctas=two_ctas)
for i in gl.static_range(ACC_STAGES):
mbarrier.init(clc_barriers.index(i), count=1)
mbarrier.init(clc_planar_ready_bars.index(i), count=1)
mbarrier.init(clc_consumed_bars.index(i), count=n_partitions - 1)
cga_layout: gl.constexpr = [[0]] * (gl.num_ctas().bit_length() - 1)
clc_layout: gl.constexpr = gl.SwizzledSharedLayout(1, 1, 1, [0], cga_layout=cga_layout)
clc_result_buffers = gl.allocate_shared_memory(
gl.int64,
[clc_barriers.shape[0], 2],
clc_layout,
)
clc_planar_pid_buffers = gl.allocate_shared_memory(gl.int64, [clc_barriers.shape[0], 1], clc_layout)
p = MatmulPartitionArgs(
a_desc,
b_desc,
c_desc,
a_bufs,
b_bufs,
load_empty_bars,
load_ready_bars,
acc_bufs,
acc_empty_bars,
acc_ready_bars,
clc_result_buffers,
clc_barriers,
clc_planar_pid_buffers,
clc_planar_ready_bars,
clc_consumed_bars,
GRID_MINOR_DIM,
GRID_TILE_WIDTH,
SUBTILE_STAGES,
)
gl.warp_specialize([
(matmul_epilogue_partition, (p, )),
(matmul_load_partition, (p, )),
(matmul_mma_partition, (p, )),
(matmul_clc_partition, (p, )),
], [1, 1, 1], [24, 24, 24])
def matmul_multicta(
a,
b,
out=None,
*,
block_size_m=128,
block_size_n=256,
block_size_k=64,
grid_minor_dim=0,
grid_tile_width=16,
stages=6,
acc_stages=2,
cga_layout=((1, 0), ),
epilogue_size_n=32,
subtile_stages=4,
):
if block_size_n // get_split_dim(cga_layout, 1) > 256:
raise ValueError(
f"cga_layout={list(cga_layout)} only supports BLOCK_SIZE_N <= {256 * get_split_dim(cga_layout, 1)}")
M, K = a.shape
K1, N = b.shape
if K != K1:
raise ValueError(f"incompatible shapes: {a.shape} and {b.shape}")
if a.dtype != torch.float16 or b.dtype != torch.float16:
raise ValueError("matmul only supports fp16 inputs")
if out is None:
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
else:
if out.shape != (M, N):
raise ValueError(f"Output has invalid shape {out.shape}, expected {(M, N)}")
c = out
tile_m = block_size_m * get_split_dim(cga_layout, 0)
two_ctas = bool(cga_layout)
a_layout = gl.NVMMASharedLayout.get_default_for([tile_m, block_size_k], gl.float16,
cga_layout=get_cga_layout(cga_layout, 0, two_ctas))
b_layout = gl.NVMMASharedLayout.get_default_for([block_size_k, block_size_n], gl.float16,
cga_layout=get_cga_layout(cga_layout, 1, two_ctas))
c_layout = gl.NVMMASharedLayout.get_default_for([tile_m, epilogue_size_n], gl.float16, cga_layout=cga_layout)
a_desc = TensorDescriptor.from_tensor(a, [tile_m, block_size_k], a_layout)
b_desc = TensorDescriptor.from_tensor(b, [block_size_k, block_size_n], b_layout)
c_desc = TensorDescriptor.from_tensor(c, [tile_m, epilogue_size_n], c_layout)
def grid(meta):
tile_m = meta["BLOCK_SIZE_M"] * get_split_dim(meta["CGA_LAYOUT"], 0)
tile_n = meta["BLOCK_SIZE_N"]
num_tiles = triton.cdiv(M, tile_m) * triton.cdiv(N, tile_n)
return (num_tiles, )
matmul_multicta_kernel[grid](
a_desc,
b_desc,
c_desc,
M,
N,
K,
block_size_m,
block_size_n,
block_size_k,
grid_minor_dim,
grid_tile_width,
stages,
acc_stages,
cga_layout,
epilogue_size_n,
subtile_stages,
num_warps=4,
num_ctas=2**len(cga_layout),
)
return c
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
def test_matmul_multicta():
M, N, K = 1024, 1024, 512
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)
c = matmul_multicta(a, b)
torch.testing.assert_close(c, torch.matmul(a, b), atol=1e-1, rtol=1e-2)
if __name__ == "__main__" and is_blackwell():
print("Benchmarking matmul_multicta")
print("============================")
cfg = {
"block_size_m": 128,
"block_size_n": 256,
"block_size_k": 64,
"grid_minor_dim": 0,
"grid_tile_width": 16,
"stages": 6,
"acc_stages": 2,
"cga_layout": ((1, 0), ),
"epilogue_size_n": 32,
"subtile_stages": 4,
}
M, N = 8192, 8192
C = torch.empty((M, N), device="cuda", dtype=torch.float16)
print(" K multi-CTA cublas")
for K in [2**i for i in range(9, 15)]:
A = torch.randn((M, K), device="cuda", dtype=torch.float16)
B = torch.randn((K, N), device="cuda", dtype=torch.float16)
BT = B.T.contiguous()
r0 = tflops(triton.testing.do_bench(lambda: matmul_multicta(A, B, out=C, **cfg), warmup=200, rep=1000), M, N, K)
r1 = tflops(triton.testing.do_bench(lambda: cublas.matmul(A, BT, C), warmup=200, rep=1000), M, N, K)
print(f"{K:>5} {r0:>17.2f} {r1:>9.2f}")
Benchmarking matmul_multicta
K multi-CTA cublas
512 1096.31 1190.98 1024 1306.07 1344.48 2048 1379.80 1374.48 4096 1444.26 1431.93 8192 1302.33 1347.82 16384 1292.40 1371.82
We are able to be competitive with cublas and even beating them in quite a range of relevant Ks for this particular configuration. If we chose different configurations for different shapes we would be able to beat cublas in a wider range of shapes.