Persistent Kernels
So far, we have defined kernels such that one program handles one block of work, and we span all the work using the grid dimensions. This launches a large number of programs and relies on the GPU to schedule them. The primary benefit is that the GPU dynamically load-balances the work across its SMs.
However, this approach has downsides. The scheduler incurs overhead, and the GPU is not aware of the kernels' memory access patterns. It also prevents overlapping across blocks of work, as the GPU waits until a program has fully exited before issuing more work to its SM.
Persistent kernels are a technique where we assign multiple blocks of work to each program, and the programs “persist” on the GPU until all the work is complete. The work assignment is typically static, although dynamic scheduling is still possible with more advanced techniques or hardware features like cluster launch control.
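The contrast between the two launch strategies can be sketched with a host-side model in plain Python (the function names and parameters here are illustrative, not part of any API):

```python
def grid_launch(num_tiles):
    # One program per block of work: the hardware scheduler load-balances.
    return [[tile] for tile in range(num_tiles)]


def persistent_launch(num_tiles, num_programs):
    # Launch only num_programs programs; each statically claims every
    # num_programs-th tile and "persists" until its share is done.
    return [list(range(pid, num_tiles, num_programs)) for pid in range(num_programs)]


work = persistent_launch(num_tiles=10, num_programs=4)
# Every tile is still processed exactly once across all programs.
assert sorted(t for tiles in work for t in tiles) == list(range(10))
```

A static assignment like this is what the tile schedulers below compute on-device; the trade-off is that a slow program cannot shed its remaining tiles to an idle SM.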
In this tutorial, we will explore persistent kernels by implementing a persistent matmul. We will then show how we can pipeline across the persistent outer loop to achieve greater overlap and more throughput.
import itertools
import pytest
import torch
import triton
import importlib
import sys
from functools import partial
from typing import Union
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
from triton.experimental.gluon.language.nvidia.hopper import (
    tma,
    mbarrier,
    fence_async_shared,
    warpgroup_mma,
    warpgroup_mma_wait,
    warpgroup_mma_accumulator,
)
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    tensor_memory_descriptor,
    allocate_tensor_memory,
    tcgen05_mma,
    tcgen05_commit,
)
if torch.cuda.is_available():
    from triton._C.libtriton import nvidia
    cublas_workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8)
    cublas = nvidia.cublas.CublasLt(cublas_workspace)
else:
    cublas = None
t5 = importlib.import_module("05-wgmma")
def is_hopper_or_newer():
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda" and torch.cuda.get_device_capability()[0] >= 9

if __name__ == "__main__" and not is_hopper_or_newer():
    raise RuntimeError("This tutorial requires Hopper or newer NVIDIA GPU")
profiling_with_ncu = len(sys.argv) > 1 and sys.argv[1] == "profile"
def get_flops(ms, M, N, K):
    flops = 2 * M * N * K
    return flops * 1e-12 / (ms * 1e-3)
In the previous two tutorials, we introduced tensor core operations for Hopper and Blackwell NVIDIA GPUs. To make this tutorial more accessible, and to demonstrate some Gluon features, we will build an abstraction around both sets of tensor core operations so that our persistent matmul can be used on both Hopper and Blackwell.
We can use @gluon.aggregate to define a class that contains the state of the matmul. We will define the API of our MMA wrapper to be like WGMMA’s, because it is the more restrictive of the two.
# MMA wrapper for WGMMA, which maps directly to the WGMMA functions.
@gluon.aggregate
class WGMMA:
    acc: Union[warpgroup_mma_accumulator, gl.tensor]
    use_acc: gl.tensor

    @gluon.jit
    def initialize(dtype: gl.constexpr, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr, num_warps: gl.constexpr):
        mma_layout: gl.constexpr = t5.pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps)
        acc = gl.zeros((BLOCK_M, BLOCK_N), dtype=gl.float32, layout=mma_layout)
        return WGMMA(acc, gl.to_tensor(False))

    @gluon.jit
    def issue_async_mma(self, a, b):
        acc = warpgroup_mma(a, b, self.acc, is_async=True, use_acc=self.use_acc)
        # Note that aggregates don't support in-place mutation, so we need to
        # return a new instance and re-assign it at the callsite.
        return WGMMA(acc, gl.to_tensor(True))

    @gluon.jit
    def wait_num_outstanding(self, num_outstanding: gl.constexpr):
        acc = warpgroup_mma_wait(num_outstanding, (self.acc, ))
        return WGMMA(acc, self.use_acc)

    # Take the result and reset the accumulator.
    @gluon.jit
    def take_result(self):
        return self.acc, WGMMA(self.acc, gl.to_tensor(False))
# MMA wrapper for tcgen05. In order to implement `wait_num_outstanding`, we
# need to allocate barriers and keep track of how many MMAs have been issued.
# State will be tracked with an accumulator.
@gluon.aggregate
class MMAv5:
    use_acc: gl.tensor
    acc_tmem: tensor_memory_descriptor
    bar: gl.shared_memory_descriptor
    counter: gl.tensor

    @gluon.jit
    def initialize(dtype: gl.constexpr, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr, num_warps: gl.constexpr):
        layout: gl.constexpr = TensorMemoryLayout([BLOCK_M, BLOCK_N], col_stride=1)
        acc_tmem = allocate_tensor_memory(gl.float32, [BLOCK_M, BLOCK_N], layout)
        bar = gl.allocate_shared_memory(gl.int64, [1], mbarrier.MBarrierLayout())
        mbarrier.init(bar, count=1)
        return MMAv5(gl.to_tensor(False), acc_tmem, bar, gl.to_tensor(0))

    @gluon.jit
    def issue_async_mma(self, a, b):
        tcgen05_mma(a, b, self.acc_tmem, use_acc=self.use_acc)
        tcgen05_commit(self.bar)
        return MMAv5(gl.to_tensor(True), self.acc_tmem, self.bar, self.counter + 1)

    @gluon.jit
    def wait_num_outstanding(self, num_outstanding: gl.constexpr):
        mbarrier.wait(self.bar, (self.counter - 1 - num_outstanding) & 1)
        return self

    @gluon.jit
    def take_result(self):
        next = MMAv5(gl.to_tensor(False), self.acc_tmem, self.bar, self.counter)
        return self.acc_tmem.load(), next
def select_mma_impl():
    if torch.cuda.get_device_capability()[0] == 9:
        return WGMMA
    elif torch.cuda.get_device_capability()[0] == 10:
        return MMAv5
    else:
        return None
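The phase argument passed to mbarrier.wait in MMAv5 deserves a note: each completed MMA flips the barrier's parity, so waiting until at most num_outstanding MMAs remain in flight means waiting on the parity of completion number counter - 1 - num_outstanding. A host-side sketch of that arithmetic (plain Python, illustrative only):

```python
def wait_phase(counter, num_outstanding):
    # `counter` MMAs have been issued; to leave at most `num_outstanding`
    # in flight, wait for completion number counter - 1 - num_outstanding.
    # Its parity bit is the mbarrier phase to wait on.
    return (counter - 1 - num_outstanding) & 1


# With num_outstanding=0 the phase alternates as each MMA is issued.
phases = [wait_phase(counter, 0) for counter in range(1, 5)]
assert phases == [0, 1, 0, 1]
```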
Let’s validate our abstraction by implementing a matmul where we pipeline both the MMA and the loads. Overlapping the TMA loads and the MMAs asynchronously requires at least two operand buffers. This will also make the persistent kernel more interesting by allowing us to overlap more things.
We will factor our kernel into components we can re-use between implementations.
@gluon.jit
def issue_loads(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, num_buffers: gl.constexpr, pred=True):
    index = producer % num_buffers
    producer += 1
    bar = bars.index(index)
    mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes, pred=pred)
    tma.async_load(a_desc, [off_m, k], bar, a_bufs.index(index), pred)
    tma.async_load(b_desc, [k, off_n], bar, b_bufs.index(index), pred)
    return producer
@gluon.jit
def issue_mma(consumer, mma, bars, a_bufs, b_bufs, num_buffers: gl.constexpr):
    index = consumer % num_buffers
    phase = consumer // num_buffers & 1
    consumer += 1
    mbarrier.wait(bars.index(index), phase)
    mma = mma.wait_num_outstanding(0)
    mma = mma.issue_async_mma(a_bufs.index(index), b_bufs.index(index))
    return consumer, mma
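Both helpers derive the buffer slot and mbarrier phase from a monotonically increasing counter, which avoids explicit wraparound logic. A quick host-side model of the indexing (plain Python; NUM_BUFFERS is an illustrative stand-in for the num_buffers constexpr):

```python
NUM_BUFFERS = 3


def slot_and_phase(counter):
    # The slot cycles through the ring; the phase flips once per full
    # trip around the ring, matching the mbarrier's phase parity.
    return counter % NUM_BUFFERS, (counter // NUM_BUFFERS) & 1


history = [slot_and_phase(i) for i in range(7)]
assert history == [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 0)]
```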
@gluon.jit
def matmul_pipelined_kernel(a_desc, b_desc, c_desc, MMAImpl: gl.constexpr, num_buffers: gl.constexpr,
                            num_warps: gl.constexpr):
    BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
    BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
    BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
    dtype: gl.constexpr = a_desc.dtype
    K = a_desc.shape[1]

    gl.static_assert(num_buffers >= 2, "expected at least 2 buffers")
    a_bufs = gl.allocate_shared_memory(dtype, [num_buffers] + a_desc.block_type.shape, a_desc.layout)
    b_bufs = gl.allocate_shared_memory(dtype, [num_buffers] + b_desc.block_type.shape, b_desc.layout)
    bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    for i in gl.static_range(num_buffers):
        mbarrier.init(bars.index(i), count=1)

    # Separate producer and consumer indices, to support more than 2 buffers.
    producer = 0
    consumer = 0

    pid_m = gl.program_id(axis=0)
    pid_n = gl.program_id(axis=1)
    off_m = pid_m * BLOCK_M
    off_n = pid_n * BLOCK_N

    # Use our MMA abstraction!
    mma = MMAImpl.initialize(dtype, BLOCK_M, BLOCK_N, num_warps)

    # Prefetch at most num_buffers-2 loads to allow the MMA to overlap.
    for k in gl.static_range(0, BLOCK_K * (num_buffers - 2), BLOCK_K):
        producer = issue_loads(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, num_buffers)
    for k in range(BLOCK_K * (num_buffers - 2), K, BLOCK_K):
        producer = issue_loads(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, num_buffers)
        consumer, mma = issue_mma(consumer, mma, bars, a_bufs, b_bufs, num_buffers)
    for _ in gl.static_range(num_buffers - 2):
        consumer, mma = issue_mma(consumer, mma, bars, a_bufs, b_bufs, num_buffers)
    mma = mma.wait_num_outstanding(0)

    c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
    c, mma = mma.take_result()
    c_smem.store(c.to(dtype))
    fence_async_shared()
    tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)
    tma.store_wait(pendings=0)
def matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps):
    MMAImpl = select_mma_impl()
    M, N = C.shape
    a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16)
    b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gl.float16)
    c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16)
    a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout)
    b_desc = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N], b_layout)
    c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, MMAImpl, num_buffers, num_warps=num_warps)
@pytest.mark.parametrize("M, N, K", [(2000, 1000, 2000)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 64)])
@pytest.mark.parametrize("num_buffers", [2, 3, 4])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
def test_pipelined_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps):
    torch.manual_seed(0)
    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
    B = torch.randn(K, N, device="cuda", dtype=torch.float16)
    C = torch.empty(M, N, device="cuda", dtype=torch.float16)
    matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps)
    torch.testing.assert_close(A @ B, C, rtol=1e-3, atol=1e-1)
The optimal block shapes for our kernel are BLOCK_M=128 and BLOCK_N=256, which gives the maximum instruction shape on both Blackwell and Hopper. However, on Hopper we need 8 warps to fit the accumulator in registers.
if __name__ == "__main__":
    M, N, K = 8192, 8192, 16 * 1024
    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
    B = torch.randn(K, N, device="cuda", dtype=torch.float16)
    C = torch.empty(M, N, device="cuda", dtype=torch.float16)
if __name__ == "__main__" and not profiling_with_ncu:
    BLOCK_M = 128
    BLOCK_N = 256
    is_hopper = torch.cuda.get_device_capability()[0] == 9
    warps = [8] if is_hopper else [4, 8]

    print("Benchmarking pipelined matmul")
    print("=============================")
    print(f"BLOCK_M={BLOCK_M}, BLOCK_N={BLOCK_N}")
    print("BLOCK_K num_buffers num_warps tflops/s")
    for (BLOCK_K, num_buffers), num_warps in itertools.product([(128, 2), (64, 3), (64, 4)], warps):
        print(f"{BLOCK_K:>7} {num_buffers:>11} {num_warps:>9}", end=" ")
        fn = lambda: matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps)
        ms = triton.testing.do_bench_cudagraph(fn)
        print(f"{get_flops(ms, M, N, K):8.2f}")
    print()
BLOCK_K num_buffers num_warps Blackwell Hopper
128 2 4 735.96
128 2 8 697.97 489.26
64 3 4 1054.00
64 3 8 973.94 673.67
64 4 4 1175.70
64 4 8 1072.83 669.16
Blackwell performance lines up with what we have seen in previous tutorials, but on Hopper we see some wins. On Hopper, performance plateaus at 3 buffers, whereas Blackwell still benefits from a 4th buffer. This suggests the MMA-to-memory throughput ratio has increased in favour of MMAs from Hopper to Blackwell. Note that our kernels run at an occupancy of 1.
To make the kernel persistent, all we have to do is put an outer loop around the kernel and iterate over the output tiles assigned to that kernel.
Let’s define a tile scheduler abstraction that will allow us to change the scheduling strategy, starting with a basic row-major tile scheduler.
@gluon.aggregate
class PersistentTileScheduler:
    pid_start: gl.tensor
    pid_end: gl.tensor
    num_pid_m: gl.tensor

    @gluon.jit
    def initialize(M, N, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr):
        kernel_id = gl.program_id(axis=0)
        num_kernels = gl.num_programs(axis=0)
        num_pid_m = gl.cdiv(M, BLOCK_M)
        num_pid_n = gl.cdiv(N, BLOCK_N)
        num_pid = num_pid_m * num_pid_n
        pid_per_kernel = gl.cdiv(num_pid, num_kernels)
        pid_start = kernel_id * pid_per_kernel
        pid_end = min(pid_start + pid_per_kernel, num_pid)
        return PersistentTileScheduler(pid_start, pid_end, num_pid_m)

    @gluon.jit
    def get_num_tiles(self):
        return self.pid_end - self.pid_start

    @gluon.jit
    def get_tile(self, idx):
        # Delinearize the tile ID along M.
        pid = self.pid_start + idx
        pid_m = pid % self.num_pid_m
        pid_n = pid // self.num_pid_m
        return pid_m, pid_n
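The delinearization in get_tile is easy to check with a host-side mirror (plain Python; num_pid_m=3 is an arbitrary illustrative value): consecutive tile IDs walk down M first, then advance one column in N.

```python
def get_tile(pid, num_pid_m):
    # Delinearize the flat tile ID along M first.
    return pid % num_pid_m, pid // num_pid_m


num_pid_m = 3
tiles = [get_tile(pid, num_pid_m) for pid in range(6)]
assert tiles == [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]
```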
We can make the kernel persistent by literally placing the outer loop around the whole kernel, but let’s re-use the TMA barriers and MMA state. We must scope the operand buffers to the inner loop so the shared memory allocator knows their live ranges do not intersect with the TMA store buffer.
@gluon.jit
def persistent_matmul_kernel(a_desc, b_desc, c_desc, MMAImpl: gl.constexpr, SchedulerImpl: gl.constexpr,
                             num_buffers: gl.constexpr, num_warps: gl.constexpr):
    BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
    BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
    BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
    dtype: gl.constexpr = a_desc.dtype
    K = a_desc.shape[1]

    bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    for i in gl.static_range(num_buffers):
        mbarrier.init(bars.index(i), count=1)

    # Producer and consumer indices.
    producer = 0
    consumer = 0

    mma = MMAImpl.initialize(dtype, BLOCK_M, BLOCK_N, num_warps)
    scheduler = SchedulerImpl.initialize(c_desc.shape[0], c_desc.shape[1], BLOCK_M, BLOCK_N)
    for idx in range(scheduler.get_num_tiles()):
        pid_m, pid_n = scheduler.get_tile(idx)
        off_m = pid_m * BLOCK_M
        off_n = pid_n * BLOCK_N

        a_bufs = gl.allocate_shared_memory(dtype, [num_buffers] + a_desc.block_type.shape, a_desc.layout)
        b_bufs = gl.allocate_shared_memory(dtype, [num_buffers] + b_desc.block_type.shape, b_desc.layout)
        for k in gl.static_range(0, BLOCK_K * (num_buffers - 2), BLOCK_K):
            producer = issue_loads(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, num_buffers)
        for k in range(BLOCK_K * (num_buffers - 2), K, BLOCK_K):
            producer = issue_loads(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, num_buffers)
            consumer, mma = issue_mma(consumer, mma, bars, a_bufs, b_bufs, num_buffers)
        for _ in gl.static_range(num_buffers - 2):
            consumer, mma = issue_mma(consumer, mma, bars, a_bufs, b_bufs, num_buffers)
        mma = mma.wait_num_outstanding(0)

        c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
        c, mma = mma.take_result()
        c_smem.store(c.to(dtype))
        fence_async_shared()
        tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)
        tma.store_wait(pendings=0)
def persistent_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps, SchedulerImpl):
    M, N = C.shape
    MMAImpl = select_mma_impl()
    a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16)
    b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gl.float16)
    c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16)
    a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout)
    b_desc = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N], b_layout)
    c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout)
    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
    num_pid = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
    grid = (min(num_sms, num_pid), )
    persistent_matmul_kernel[grid](a_desc, b_desc, c_desc, MMAImpl, SchedulerImpl, num_buffers, num_warps=num_warps)
schedulers = [PersistentTileScheduler]
@pytest.mark.parametrize("M, N, K", [(2000, 1000, 2000)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 256, 64)])
@pytest.mark.parametrize("num_buffers", [2, 3, 4])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("SchedulerImpl", schedulers)
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
def test_persistent_matmul(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps, SchedulerImpl):
    torch.manual_seed(0)
    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
    B = torch.randn(K, N, device="cuda", dtype=torch.float16)
    C = torch.empty(M, N, device="cuda", dtype=torch.float16)
    persistent_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps, SchedulerImpl)
    torch.testing.assert_close(A @ B, C, rtol=1e-3, atol=1e-1)
if __name__ == "__main__" and not profiling_with_ncu:
    print("Benchmarking persistent matmul")
    print("==============================")
    print(f"BLOCK_M={BLOCK_M} BLOCK_N={BLOCK_N}")
    print("BLOCK_K num_buffers num_warps tflops/s")
    for (BLOCK_K, num_buffers), num_warps in itertools.product([(128, 2), (64, 3), (64, 4)], warps):
        print(f"{BLOCK_K:>7} {num_buffers:>11} {num_warps:>9}", end=" ")
        fn = lambda: persistent_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps,
                                       PersistentTileScheduler)
        ms = triton.testing.do_bench_cudagraph(fn)
        print(f"{get_flops(ms, M, N, K):8.2f}")
    print()
BLOCK_K num_buffers num_warps Blackwell Hopper
128 2 4 712.25
128 2 8 686.64 502.84
64 3 4 1032.16
64 3 8 938.81 661.11
64 4 4 1142.26
64 4 8 1071.46 658.84
The Hopper kernel sees a modest improvement, but the Blackwell kernel’s performance is slightly lower. Let’s capture a profile of the kernels on Blackwell using ncu. Pass profile as this script’s first argument to run each kernel once.
if __name__ == "__main__" and profiling_with_ncu:
    matmul_pipelined(A, B, C, 128, 256, 64, 4, 4)
    persistent_matmul(A, B, C, 128, 256, 64, 4, 4, PersistentTileScheduler)
There are many reasons the persistent kernel can be slower. Load imbalance can arise due to inefficient scheduling (work is not evenly distributed). But it can also arise from drift at runtime, such as some TMA accesses taking longer than others, which a static tile scheduler cannot compensate for.
Another reason we suspect is the global memory access pattern:
ncu --set full -o pipelined --kernel-name matmul_pipelined_kernel python 07-persistence.py profile
ncu --set full -o persistent --kernel-name persistent_matmul_kernel python 07-persistence.py profile
ncu --import pipelined.ncu-rep | grep "L2 Hit Rate"
L2 Hit Rate % 61.11
ncu --import persistent.ncu-rep | grep "L2 Hit Rate"
L2 Hit Rate % 52.93
The persistent kernel’s L2 hit rate is 10% lower. We can improve L2 efficiency by “super-grouping” the tiles along columns. See 03-matrix-multiplication.py for more details. Let’s encode this strategy in a new tile scheduler.
def GroupedPersistentTileScheduler(GROUP_SIZE_M):
    # Bind this as a constexpr so it can be captured.
    GROUP_SIZE_M = gl.constexpr(GROUP_SIZE_M)

    # Like C++ templates!
    @gluon.aggregate
    class GroupedPersistentTileSchedulerImpl:
        start_pid: gl.tensor
        num_pid_m: gl.tensor
        num_pid_in_group: gl.tensor
        num_pid: gl.tensor

        @gluon.jit
        def initialize(M, N, BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr):
            start_pid = gl.program_id(axis=0)
            num_pid_m = gl.cdiv(M, BLOCK_M)
            num_pid_n = gl.cdiv(N, BLOCK_N)
            num_pid_in_group = GROUP_SIZE_M * num_pid_n
            num_pid = num_pid_m * num_pid_n
            return GroupedPersistentTileSchedulerImpl(start_pid, num_pid_m, num_pid_in_group, num_pid)

        @gluon.jit
        def get_num_tiles(self):
            return gl.cdiv(self.num_pid - self.start_pid, gl.num_programs(axis=0))

        @gluon.jit
        def get_tile(self, idx):
            tile_id = self.start_pid + idx * gl.num_programs(axis=0)
            group_id = tile_id // self.num_pid_in_group
            first_pid_m = group_id * GROUP_SIZE_M
            group_size_m = min(self.num_pid_m - first_pid_m, GROUP_SIZE_M)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % self.num_pid_in_group) // group_size_m
            return pid_m, pid_n

    GroupedPersistentTileSchedulerImpl.__name__ = f"GroupedPersistentTileScheduler({GROUP_SIZE_M.value})"
    return GroupedPersistentTileSchedulerImpl
# Add this to the testsuite.
schedulers += [GroupedPersistentTileScheduler(1), GroupedPersistentTileScheduler(8)]
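The grouped delinearization is trickier than the row-major one, so it is worth verifying with a host-side mirror that every tile is still produced exactly once across all programs (plain Python; the grid and program counts are illustrative, chosen so GROUP_SIZE_M divides num_pid_m):

```python
def grouped_tiles(num_pid_m, num_pid_n, group_m, num_programs):
    # Host-side mirror of GroupedPersistentTileSchedulerImpl.get_tile,
    # iterated over every program's strided share of the tile IDs.
    num_pid_in_group = group_m * num_pid_n
    num_pid = num_pid_m * num_pid_n
    seen = []
    for start_pid in range(num_programs):
        for tile_id in range(start_pid, num_pid, num_programs):
            group_id = tile_id // num_pid_in_group
            first_pid_m = group_id * group_m
            group_size_m = min(num_pid_m - first_pid_m, group_m)
            pid_m = first_pid_m + (tile_id % group_size_m)
            pid_n = (tile_id % num_pid_in_group) // group_size_m
            seen.append((pid_m, pid_n))
    return seen


tiles = grouped_tiles(num_pid_m=8, num_pid_n=4, group_m=4, num_programs=5)
# Coverage: all 32 tiles appear exactly once.
assert sorted(tiles) == [(m, n) for m in range(8) for n in range(4)]
```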
if __name__ == "__main__" and not profiling_with_ncu:
    num_warps = 8 if is_hopper else 4
    num_buffers = 3 if is_hopper else 4
    print("Benchmarking grouped scheduler")
    print("==============================")
    print(f"BLOCK_M={BLOCK_M} BLOCK_N={BLOCK_N} BLOCK_K={BLOCK_K}")
    print(f"num_buffers={num_buffers} num_warps={num_warps}")
    print("GROUP_SIZE_M tflops/s")
    for GROUP_SIZE_M in [1, 2, 4, 6, 8]:
        print(f"{GROUP_SIZE_M:>12}", end=" ")
        fn = lambda: persistent_matmul(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps,
                                       GroupedPersistentTileScheduler(GROUP_SIZE_M))
        ms = triton.testing.do_bench_cudagraph(fn)
        print(f"{get_flops(ms, M, N, K):8.2f}")
    print()
GROUP_SIZE_M Blackwell Hopper
1 1025.11 649.09
2 1050.43 651.32
4 1032.71 655.51
6 1057.27 652.39
8 1179.94 648.42
At GROUP_SIZE_M=8, we recover performance on Blackwell. In fact, under ncu we see the L2 hit rate increases to 70%, which suggests there are other ways to improve the scheduling.
Performance decreases on Hopper with this scheduler. The persistent kernel’s L2 hit rate is 86%, versus 89% for the non-persistent kernel. The grouped scheduler does not affect the L2 hit rate, but it does increase load imbalance.
Pipelining across the outer loop benefits smaller K shapes more because a larger proportion of time is spent in the epilogue. We can try overlapping the TMA store with the next tile by rotating the TMA store wait.
However, this causes the live range of the TMA store buffer to overlap with the operand buffers, decreasing our max num_buffers to 3. While Hopper is fine with 3 buffers, on Blackwell performance can suffer. There are 3 remedies:

- Use gl.store, which does not require a dedicated store buffer, but cannot be pipelined. Moreover, the layout conversion still requires shared memory.
- Break the TMA store into multiple steps, allowing us to use smaller buffers, but then only the last step can be pipelined, which reduces the amount of overlap.
- Borrow one of the b_bufs.
For BLOCK_{M,N,K} = (128, 256, 64), one B buffer is half the size of the accumulator tile. We have enough shared memory to allocate 5 B buffers, even though the inner loop only uses 4 at a time, just so that the epilogue can steal two consecutive buffers.
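The sizes work out as follows for the tutorial's fp16 configuration (a quick host-side arithmetic check; the 4 and 4+1 buffer counts correspond to STEALB with num_buffers=4):

```python
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64
FP16 = 2  # bytes per element

b_buf = BLOCK_K * BLOCK_N * FP16  # one B operand buffer
c_buf = BLOCK_M * BLOCK_N * FP16  # the fp16 output tile in shared memory

# One B buffer is exactly half an output tile, so stealing two
# consecutive B buffers is enough for the epilogue.
assert 2 * b_buf == c_buf

# Total operand smem with 4 A buffers and 4+1 B buffers, in KiB.
total_kib = (4 * BLOCK_M * BLOCK_K + 5 * BLOCK_K * BLOCK_N) * FP16 // 1024
assert total_kib == 224
```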
# Forked versions of issue_loads and issue_mma that support `stealb`.
@gluon.jit
def issue_loads_stealb(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, stealb: gl.constexpr,
                       num_buffers: gl.constexpr, pred=True):
    index = producer % num_buffers
    b_index = producer % (num_buffers + stealb)
    producer += 1
    bar = bars.index(index)
    mbarrier.expect(bar, a_desc.block_type.nbytes + b_desc.block_type.nbytes, pred=pred)
    tma.async_load(a_desc, [off_m, k], bar, a_bufs.index(index), pred)
    tma.async_load(b_desc, [k, off_n], bar, b_bufs.index(b_index), pred)
    return producer
@gluon.jit
def issue_mma_stealb(consumer, mma, bars, a_bufs, b_bufs, stealb: gl.constexpr, num_buffers: gl.constexpr):
    index = consumer % num_buffers
    b_index = consumer % (num_buffers + stealb)
    phase = consumer // num_buffers & 1
    consumer += 1
    mbarrier.wait(bars.index(index), phase)
    mma = mma.wait_num_outstanding(0)
    mma = mma.issue_async_mma(a_bufs.index(index), b_bufs.index(b_index))
    return consumer, mma
@gluon.jit
def persistent_matmul_pipelined_kernel(a_desc, b_desc, c_desc, MMAImpl: gl.constexpr, SchedulerImpl: gl.constexpr,
                                       num_buffers: gl.constexpr, STEALB: gl.constexpr, num_warps: gl.constexpr):
    BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
    BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
    BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
    dtype: gl.constexpr = a_desc.dtype
    K = a_desc.shape[1]

    # All buffers share the same live range.
    gl.static_assert(num_buffers >= 3, "expected at least 3 buffers")
    a_bufs = gl.allocate_shared_memory(dtype, [num_buffers] + a_desc.block_type.shape, a_desc.layout)
    # Add an extra B buffer when stealing.
    b_bufs = gl.allocate_shared_memory(dtype, [num_buffers + STEALB] + b_desc.block_type.shape, b_desc.layout)
    if not STEALB:
        c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
    else:
        gl.static_assert(2 * BLOCK_N * BLOCK_K >= BLOCK_M * BLOCK_N, "B tile not large enough to steal")
    bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    for i in gl.static_range(num_buffers):
        mbarrier.init(bars.index(i), count=1)

    producer = 0
    consumer = 0
    mma = MMAImpl.initialize(dtype, BLOCK_M, BLOCK_N, num_warps)
    scheduler = SchedulerImpl.initialize(c_desc.shape[0], c_desc.shape[1], BLOCK_M, BLOCK_N)
    num_tiles = scheduler.get_num_tiles()

    # Peeled inner loop prologue.
    idx = 0
    pid_m, pid_n = scheduler.get_tile(idx)
    off_m = pid_m * BLOCK_M
    off_n = pid_n * BLOCK_N
    for ki in gl.static_range(0, BLOCK_K * (num_buffers - 2), BLOCK_K):
        producer = issue_loads_stealb(producer, a_desc, b_desc, off_m, off_n, ki, bars, a_bufs, b_bufs, STEALB,
                                      num_buffers)
    k = BLOCK_K * (num_buffers - 2)
    producer = issue_loads_stealb(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, STEALB, num_buffers)

    for _ in range(num_tiles):
        consumer, mma = issue_mma_stealb(consumer, mma, bars, a_bufs, b_bufs, STEALB, num_buffers)
        if STEALB:
            # Wait for the epilogue before the first TMA load.
            tma.store_wait(pendings=0)
        for k in range(BLOCK_K * (num_buffers - 1), K, BLOCK_K):
            producer = issue_loads_stealb(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, STEALB,
                                          num_buffers)
            consumer, mma = issue_mma_stealb(consumer, mma, bars, a_bufs, b_bufs, STEALB, num_buffers)
        epilogue_off_m = off_m
        epilogue_off_n = off_n

        # Peel the next prologue and fuse it with the pipeline drain loop.
        idx += 1
        pid_m, pid_n = scheduler.get_tile(idx)
        off_m = pid_m * BLOCK_M
        off_n = pid_n * BLOCK_N
        # Predicate the peeled prologue instead of using a conditional.
        pred = idx < num_tiles
        for ki in gl.static_range(0, BLOCK_K * (num_buffers - 2), BLOCK_K):
            producer = issue_loads_stealb(producer, a_desc, b_desc, off_m, off_n, ki, bars, a_bufs, b_bufs, STEALB,
                                          num_buffers, pred)
            consumer, mma = issue_mma_stealb(consumer, mma, bars, a_bufs, b_bufs, STEALB, num_buffers)
        k = BLOCK_K * (num_buffers - 2)
        producer = issue_loads_stealb(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, STEALB,
                                      num_buffers, pred)
        mma = mma.wait_num_outstanding(0)

        c, mma = mma.take_result()
        c = c.to(dtype)
        if not STEALB:
            c_buf = c_smem
            tma.store_wait(pendings=0)
        else:
            # Steal the next 2 B buffers for the epilogue.
            c_buf = b_bufs.index(producer % (num_buffers + STEALB))._reinterpret(dtype, c_desc.block_type.shape,
                                                                                c_desc.layout)
        c_buf.store(c)
        fence_async_shared()
        tma.async_copy_shared_to_global(c_desc, [epilogue_off_m, epilogue_off_n], c_buf)
    tma.store_wait(pendings=0)
def persistent_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps, SchedulerImpl):
    M, N = C.shape
    MMAImpl = select_mma_impl()
    a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16)
    b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gl.float16)
    c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16)
    a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout)
    b_desc = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N], b_layout)
    c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout)
    num_sms = torch.cuda.get_device_properties("cuda").multi_processor_count
    num_pid = triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)
    grid = (min(num_sms, num_pid), )
    persistent_matmul_pipelined_kernel[grid](a_desc, b_desc, c_desc, MMAImpl, SchedulerImpl, num_buffers,
                                             STEALB=num_buffers == 4, num_warps=num_warps)
@pytest.mark.parametrize("M, N, K", [(208, 416, 304), (2000, 1000, 2000)])
@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(64, 64, 64), (128, 256, 64)])
@pytest.mark.parametrize("num_buffers", [3, 4])
@pytest.mark.parametrize("num_warps", [4, 8])
@pytest.mark.parametrize("SchedulerImpl", schedulers)
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Requires Hopper or newer")
def test_persistent_matmul_pipelined(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps, SchedulerImpl):
    torch.manual_seed(0)
    A = torch.randn(M, K, device="cuda", dtype=torch.float16)
    B = torch.randn(K, N, device="cuda", dtype=torch.float16)
    C = torch.empty(M, N, device="cuda", dtype=torch.float16)
    persistent_matmul_pipelined(A, B, C, BLOCK_M, BLOCK_N, BLOCK_K, num_buffers, num_warps, SchedulerImpl)
    torch.testing.assert_close(A @ B, C, rtol=1e-3, atol=1e-1)
if __name__ == "__main__":
    args = {
        "BLOCK_M": 128,
        "BLOCK_N": 256,
        "BLOCK_K": 64,
        "num_buffers": 3 if is_hopper else 4,
        "num_warps": 8 if is_hopper else 4,
    }
    scheduler = PersistentTileScheduler if is_hopper else GroupedPersistentTileScheduler(8)
    nonpersistent = partial(matmul_pipelined, **args)
    persistent = partial(persistent_matmul, **args, SchedulerImpl=scheduler)
    persistent_pipelined = partial(persistent_matmul_pipelined, **args, SchedulerImpl=scheduler)

    M, N = 8192, 8192
    C = torch.empty(M, N, device="cuda", dtype=torch.float16)

    print("Benchmarking pipelined persistent")
    print("=================================")
    print("    K nonpersistent persistent pipelined cublas")
    for K in [2**i for i in range(9, 15)]:
        as_flops = partial(get_flops, M=M, N=N, K=K)
        A = torch.randn(M, K, device="cuda", dtype=torch.float16)
        B = torch.randn(K, N, device="cuda", dtype=torch.float16)
        BT = B.T.contiguous()
        r0 = as_flops(triton.testing.do_bench_cudagraph(lambda: nonpersistent(A, B, C)))
        r1 = as_flops(triton.testing.do_bench_cudagraph(lambda: persistent(A, B, C)))
        r2 = as_flops(triton.testing.do_bench_cudagraph(lambda: persistent_pipelined(A, B, C)))
        r3 = as_flops(triton.testing.do_bench(lambda: cublas.matmul(A, BT, C)))
        print(f"{K:>5} {r0:>17.2f} {r1:>13.2f} {r2:>11.2f} {r3:>9.2f}")
Blackwell results:
K nonpersistent persistent pipelined cublas
512 615.86 828.70 993.50 1108.11
1024 997.16 1077.28 1173.31 1347.44
2048 1152.74 1190.55 1133.37 1435.01
4096 1164.05 1120.92 1143.47 1563.98
8192 1160.93 1074.97 1185.40 1491.84
16384 1185.62 1096.34 1296.93 1548.42
Hopper results:
K nonpersistent persistent pipelined cublas
512 491.74 485.01 539.88 588.15
1024 554.24 575.02 602.52 588.32
2048 573.87 594.72 625.91 615.58
4096 609.36 630.10 640.48 646.30
8192 629.44 646.22 661.57 661.11
16384 653.79 660.29 670.00 665.49
Persistent matmul, when pipelined, gains more performance relative to nonpersistent at lower K, as we would expect. Load balancing can be particularly difficult when the number of SMs does not evenly divide the number of blocks; with 8192x8192, we are smack in the middle with ~13.5 and ~15.5 blocks per SM for Hopper and Blackwell, respectively.
On Hopper, our pipelined kernel is competitive with cublas, even pulling ahead for medium-sized K. However, cublas has a definitive advantage at low K. On Blackwell, it’s not even close: cublas is significantly faster.
Some matmul performance takes:

- On Hopper, software pipelining is sufficient to reach peak performance for medium and large K.
- cublas uses a 2-CTA matmul, which uses distributed shared memory to enable a 256x256 instruction shape. 2-CTA support in Gluon is very spotty, but this lets cublas feed the MMA more efficiently, which matters more on Blackwell due to the relative increase in MMA throughput vs TMA.
- The cublas matmul is warp-specialized, which is necessary on Hopper to fully overlap the epilogue at small K.
- Our Blackwell implementation is limited by the shared API we designed for Hopper and Blackwell: we are not double-buffering the accumulator, leaving 256 columns of TMEM unused.
- On Blackwell, we can use cluster launch control to dynamically schedule work in conjunction with the GPU, getting the best of both worlds. This is explored further in tutorial 12.
Main takeaways:

- Persistent kernels replace GPU block scheduling with a (typically) static schedule. This allows more resource and compute coordination/overlap between blocks, at the cost of losing dynamic scheduling.
- Persistent kernels tend to benefit smaller problem sizes, but still deliver benefits for large problem sizes.