Cluster Launch Control (CLC)

Cluster Launch Control (CLC) is a Blackwell (SM100+) hardware feature that enables dynamic work distribution between thread blocks. When a block finishes early, it can cancel a not-yet-launched cluster and take over its work, improving load balancing.

This tutorial demonstrates:

  1. The CLC API: try_cancel, is_canceled, get_first_ctaid

  2. How to overlap CLC with computation to hide latency

  3. A comparison with a statically scheduled persistent matmul

Key Insight

The critical optimization is issuing CLC during the TMA prologue and checking the result after tile completion. This hides CLC latency behind computation.

CLC API

  • clc.try_cancel(result, mbar): Issue async CLC request to cancel a pending cluster

  • clc_result = clc.load_result(result): Load CLC response into registers

  • clc_result.is_canceled(): Returns True if a cluster was successfully canceled

  • clc_result.program_id(dim): Get the canceled cluster’s program ID


import torch
import triton
import pytest
import importlib

from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.blackwell import TensorDescriptor
from triton.experimental.gluon.language.nvidia.blackwell import tma, mbarrier, fence_async_shared, clc


def is_blackwell():
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 10


if __name__ == "__main__" and not is_blackwell():
    raise RuntimeError("This tutorial requires Blackwell (SM100+) GPU")

# Re-use helpers from tutorial 7.
t7 = importlib.import_module("07-persistence")

CLC Matmul Kernel

This kernel processes its assigned tile, then attempts to steal additional work. CLC is issued during the prologue so the result is ready after tile completion.

This is identical to the persistent_matmul_kernel from tutorial 7, except for the changed ClcTileScheduler interface to support dynamic scheduling.



@gluon.aggregate
class ClcTileScheduler:
    has_work: gl.tensor
    tile_id: gl.tensor
    clc_result_buf: gl.shared_memory_descriptor
    barrier: gl.shared_memory_descriptor
    phase: gl.tensor

    @gluon.jit
    def initialize(M, N, BLOCK_M, BLOCK_N):
        has_work = gl.to_tensor(True)
        starting_tile_id = gl.program_id(0)

        barrier = mbarrier.allocate_mbarrier()
        clc_result_buffer = gl.allocate_shared_memory(gl.int64, [2], gl.SwizzledSharedLayout(1, 1, 1, [0]))
        mbarrier.init(barrier, count=1)
        phase = gl.to_tensor(0)
        return ClcTileScheduler(has_work, starting_tile_id, clc_result_buffer, barrier, phase)

    @gluon.jit
    def try_cancel(self) -> None:
        clc.try_cancel(self.clc_result_buf, self.barrier)
        mbarrier.expect(self.barrier, 16)

    @gluon.jit
    def advance(self):
        mbarrier.wait(self.barrier, self.phase)
        clc_res = clc.load_result(self.clc_result_buf)
        has_work = clc_res.is_canceled()
        next_tile_id = clc_res.program_id(0)
        return ClcTileScheduler(has_work, next_tile_id, self.clc_result_buf, self.barrier, self.phase ^ 1)

We also implement a static scheduler that conforms to the same interface, so we can directly compare the benifits of dynamic scheduling.



@gluon.aggregate
class StaticTileScheduler:
    has_work: gl.tensor
    tile_id: gl.tensor
    num_tiles: gl.tensor

    @gluon.jit
    def initialize(M, N, BLOCK_M, BLOCK_N):
        starting_tile_id = gl.program_id(0)
        num_tiles = gl.cdiv(M, BLOCK_M) * gl.cdiv(N, BLOCK_N)
        has_work = starting_tile_id < num_tiles
        return StaticTileScheduler(has_work, starting_tile_id, num_tiles)

    @gluon.jit
    def try_cancel(self) -> None:
        pass

    @gluon.jit
    def advance(self):
        next_tile_id = self.tile_id + gl.num_programs(0)
        has_work = next_tile_id < self.num_tiles
        return StaticTileScheduler(has_work, next_tile_id, self.num_tiles)


@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M):
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n


@gluon.jit
def persistent_matmul_kernel(a_desc, b_desc, c_desc, MMAImpl: gl.constexpr, SchedulerImpl: gl.constexpr,
                             num_buffers: gl.constexpr, num_warps: gl.constexpr):
    BLOCK_M: gl.constexpr = c_desc.block_type.shape[0]
    BLOCK_N: gl.constexpr = c_desc.block_type.shape[1]
    BLOCK_K: gl.constexpr = a_desc.block_type.shape[1]
    dtype: gl.constexpr = a_desc.dtype
    K = a_desc.shape[1]
    N = c_desc.shape[0]
    M = c_desc.shape[1]

    num_pid_n = gl.cdiv(N, BLOCK_N)
    num_pid_m = gl.cdiv(M, BLOCK_M)
    GROUP_SIZE_M: gl.constexpr = 8
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    for i in gl.static_range(num_buffers):
        mbarrier.init(bars.index(i), count=1)
    # Producer and consumer indices.
    producer = 0
    consumer = 0

    mma = MMAImpl.initialize(dtype, BLOCK_M, BLOCK_N, num_warps)
    scheduler = SchedulerImpl.initialize(M, N, BLOCK_M, BLOCK_N)
    while scheduler.has_work:
        pid_m, pid_n = _compute_pid(scheduler.tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M)
        off_m = pid_m * BLOCK_M
        off_n = pid_n * BLOCK_N

        scheduler.try_cancel()

        a_bufs = gl.allocate_shared_memory(dtype, [num_buffers] + a_desc.block_type.shape, a_desc.layout)
        b_bufs = gl.allocate_shared_memory(dtype, [num_buffers] + b_desc.block_type.shape, b_desc.layout)
        for k in gl.static_range(0, BLOCK_K * (num_buffers - 2), BLOCK_K):
            producer = t7.issue_loads(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, num_buffers)

        for k in range(BLOCK_K * (num_buffers - 2), K, BLOCK_K):
            producer = t7.issue_loads(producer, a_desc, b_desc, off_m, off_n, k, bars, a_bufs, b_bufs, num_buffers)
            consumer, mma = t7.issue_mma(consumer, mma, bars, a_bufs, b_bufs, num_buffers)

        for _ in gl.static_range(num_buffers - 2):
            consumer, mma = t7.issue_mma(consumer, mma, bars, a_bufs, b_bufs, num_buffers)

        mma = mma.wait_num_outstanding(0)
        c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
        c, mma = mma.take_result()
        c_smem.store(c.to(dtype))
        fence_async_shared()
        tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)
        tma.store_wait(pendings=0)
        scheduler = scheduler.advance()


def run_matmul_kernel(A, B, C, BLOCK_M=128, BLOCK_N=256, BLOCK_K=64, num_buffers=3, num_warps=4, use_clc=True):
    M, N = C.shape
    a_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], gl.float16)
    b_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], gl.float16)
    c_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_N], gl.float16)
    a_desc = TensorDescriptor.from_tensor(A, [BLOCK_M, BLOCK_K], a_layout)
    b_desc = TensorDescriptor.from_tensor(B, [BLOCK_K, BLOCK_N], b_layout)
    c_desc = TensorDescriptor.from_tensor(C, [BLOCK_M, BLOCK_N], c_layout)

    if use_clc:
        num_pid_m = triton.cdiv(M, BLOCK_M)
        num_pid_n = triton.cdiv(N, BLOCK_N)
        grid = num_pid_m * num_pid_n
        SchedulerImpl = ClcTileScheduler
    else:
        dev_props = torch.cuda.get_device_properties(A.device)
        grid = dev_props.multi_processor_count
        SchedulerImpl = StaticTileScheduler

    MMAImpl = t7.MMAv5
    persistent_matmul_kernel[(grid, )](a_desc, b_desc, c_desc, MMAImpl, SchedulerImpl, num_buffers, num_warps=num_warps)


@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell")
@pytest.mark.parametrize("M, N, K", [(8192, 8192, 8192), (1000, 1000, 1000)])
@pytest.mark.parametrize("use_clc", [True, False])
def test_op(M, N, K, use_clc):
    A = torch.randn(M, K, device='cuda', dtype=torch.float16)
    B = torch.randn(K, N, device='cuda', dtype=torch.float16)
    C = torch.empty(M, N, device='cuda', dtype=torch.float16)

    C_ref = torch.mm(A, B)
    run_matmul_kernel(A, B, C, use_clc=use_clc)

    torch.testing.assert_close(C, C_ref)

Benchmark



def benchmark():
    print("=" * 60)
    print("Cluster Launch Control (CLC) Matmul - Blackwell")
    print("=" * 60)

    props = torch.cuda.get_device_properties(0)
    print(f"Device: {props.name}, SMs: {props.multi_processor_count}")

    M, N, K = 8192, 8192, 8192
    print(f"Matrix: {M}x{N}x{K}")

    A = torch.randn(M, K, device='cuda', dtype=torch.float16)
    B = torch.randn(K, N, device='cuda', dtype=torch.float16)
    C = torch.empty(M, N, device='cuda', dtype=torch.float16)

    # Static baseline
    def static_fn():
        run_matmul_kernel(A, B, C, use_clc=False)

    ms = triton.testing.do_bench_cudagraph(static_fn)
    static_tflops = t7.get_flops(ms, M, N, K)
    print(f"\nStatic:     {static_tflops:7.2f} TFLOPS")

    # CLC matmul
    def clc_fn():
        run_matmul_kernel(A, B, C)

    ms = triton.testing.do_bench_cudagraph(clc_fn)
    clc_tflops = t7.get_flops(ms, M, N, K)
    print(f"CLC:        {clc_tflops:7.2f} TFLOPS ({100*clc_tflops/static_tflops:.1f}% of static)")

    # Correctness check
    print("\nVerifying correctness...")
    A_test = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    B_test = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
    C_ref = torch.mm(A_test, B_test)
    C_clc = torch.empty_like(C_ref)
    run_matmul_kernel(A_test, B_test, C_clc)
    torch.cuda.synchronize()
    max_diff = (C_ref - C_clc).abs().max().item()
    print(f"Max diff: {max_diff:.6f}")
    assert max_diff < 1.0, "Correctness check failed"
    print("✓ Correctness verified")

A sample run of this benchmark may look like,

Cluster Launch Control (CLC) Matmul - Blackwell

Device: NVIDIA GB200, SMs: 152 Matrix: 8192x8192x8192

Static: 1040.13 TFLOPS CLC: 1080.74 TFLOPS (103.9% of static)

Notice that we’ve achieved a 3.9% speedup (which will vary run to run), without improving the actual matmul computation at all. This is because there is always a slight variance between the time taken to compute each tile. For example, one may have inputs already cached in L2 and another might suffer a cache miss. With a static scheduler, the kernel takes as long as it takes the slowest SM to complete it’s assigned num_tiles / num_sms tiles. However, CLC allows us to better balance the load by give more work to the SMs that finish early and less to those that are taking longer.

This effect will be even more pronounced in kernels that have more run-time variation, e.g. in a ragged matmul where the k dim is different for different output tiles.

Note that a similar effect can be achieved on pre-blackwell by using a global atomic counter to track the next available tile id. However, this requires additional run time overhead to zero out the counter before launching the kernel which may nullify the benefit for reasonably-balanced workloads.


if __name__ == "__main__":
    benchmark()