Attention Forward

This example implements a warp-specialized FlashAttention-style forward kernel for NVIDIA Blackwell GPUs, written in Gluon. The source can be found at python/examples/gluon/01-attention-forward.py.

import copy
import math
import torch
import triton
import pytest
import itertools
from dataclasses import dataclass, fields

from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
from triton.experimental.gluon.language.nvidia.hopper import fence_async_shared
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    allocate_tensor_memory,
    tensor_memory_descriptor,
    tensor_memory_descriptor_type,
    tma,
    mbarrier,
    tcgen05_mma,
    tcgen05_commit,
    float2,
)
from triton.experimental.gluon.language.nvidia.blackwell.float2 import Float2Tensor

# ===-----------------------------------------------------------------------===#
# Layout Utilities
# ===-----------------------------------------------------------------------===#


@gluon.constexpr_function
def get_mma_instr_shape(shape, element_ty):
    m = 128 if shape[0] >= 128 else 64
    n = 256 if shape[1] >= 256 else shape[1]
    k = 256 // element_ty.primitive_bitwidth
    return (m, n, k)
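
# Example: for qk_shape = [128, 128] with an fp32 accumulator this selects a
# 128x128 MMA with k = 256 / 32 = 8, i.e. (m, n, k) = (128, 128, 8).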


# ===-----------------------------------------------------------------------===#
# Data Abstractions
# ===-----------------------------------------------------------------------===#


@gluon.aggregate
class BarrierCounter:
    index: gl.tensor
    phase: gl.tensor
    num_barriers: gl.constexpr

    @gluon.must_use_result
    @gluon.jit
    def increment(self):
        if self.num_barriers == 1:
            return BarrierCounter(gl.to_tensor(0), self.phase ^ 1, self.num_barriers)
        next_index = self.index + 1
        rollover = next_index == self.num_barriers
        index = gl.where(rollover, 0, next_index)
        phase = gl.where(rollover, self.phase ^ 1, self.phase)
        return BarrierCounter(index, phase, self.num_barriers)
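
# Example: with num_barriers=2 the counter steps (index, phase) through
# (0, 0) -> (1, 0) -> (0, 1) -> (1, 1) -> (0, 0) -> ..., flipping the phase
# exactly once per trip around the barrier ring.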


def Channel(T, alloc_fn):

    @gluon.aggregate
    class ChannelType:
        mem: T
        ready_bars: gl.shared_memory_descriptor
        empty_bars: gl.shared_memory_descriptor
        num_buffers: gl.constexpr
        num_consumers: gl.constexpr

        @gluon.jit
        def alloc(shape: gl.constexpr, dtype: gl.constexpr, layout: gl.constexpr, num_buffers: gl.constexpr,
                  num_consumers: gl.constexpr = 1):
            mem = alloc_fn(dtype, [num_buffers] + shape, layout)
            ready_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
            empty_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
            for i in gl.static_range(num_buffers):
                mbarrier.init(ready_bars.index(i), count=1)
                mbarrier.init(empty_bars.index(i), count=num_consumers)
                mbarrier.arrive(empty_bars.index(i), count=num_consumers)
            return ChannelType(mem, ready_bars, empty_bars, num_buffers, num_consumers)

        @gluon.jit
        def acquire_producer(self, counter):
            index, phase = counter.index, counter.phase
            mem = self.mem.index(index)
            ready_bar = self.ready_bars.index(index)
            empty_bar = self.empty_bars.index(index)

            mbarrier.wait(empty_bar, phase)
            return mem, ready_bar

        @gluon.jit
        def acquire_consumer(self, counter):
            index, phase = counter.index, counter.phase
            mem = self.mem.index(index)
            ready_bar = self.ready_bars.index(index)
            empty_bar = self.empty_bars.index(index)

            mbarrier.wait(ready_bar, phase)
            return mem, empty_bar

        @gluon.jit
        def create_counter(self):
            return BarrierCounter(gl.to_tensor(0), gl.to_tensor(0), self.num_buffers)

        @gluon.jit
        def create_producer(self):
            return Producer(self, self.create_counter())

        @gluon.jit
        def create_consumer(self):
            return Consumer(self, self.create_counter())

        @gluon.jit
        def release(self):
            if isinstance(self.mem, gl.shared_memory_descriptor):
                self.mem._keep_alive()
            for i in gl.static_range(self.num_buffers):
                mbarrier.invalidate(self.ready_bars.index(i))
                mbarrier.invalidate(self.empty_bars.index(i))

    @gluon.aggregate
    class Producer:
        channel: ChannelType
        counter: BarrierCounter

        @gluon.jit
        def acquire(self):
            mem, ready_bar = self.channel.acquire_producer(self.counter)
            next = Producer(self.channel, self.counter.increment())
            return mem, ready_bar, next

    @gluon.aggregate
    class Consumer:
        channel: ChannelType
        counter: BarrierCounter

        @gluon.jit
        def acquire(self):
            mem, empty_bar = self.channel.acquire_consumer(self.counter)
            next = Consumer(self.channel, self.counter.increment())
            return mem, empty_bar, next

    return ChannelType, Producer, Consumer


SharedMemoryChannel, SharedMemoryProducer, SharedMemoryConsumer = Channel(gl.shared_memory_descriptor,
                                                                          gl.allocate_shared_memory)
TensorMemoryChannel, TensorMemoryProducer, TensorMemoryConsumer = Channel(tensor_memory_descriptor,
                                                                          allocate_tensor_memory)
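
# Each channel is a ring of num_buffers slots guarded by a pair of mbarriers:
# a producer waits on empty_bars[i], fills slot i, and arrives on ready_bars[i];
# a consumer waits on ready_bars[i], drains the slot, and arrives on
# empty_bars[i]. A minimal sketch of the handshake (shapes and layout are
# hypothetical):
#
#   chnl = SharedMemoryChannel.alloc([64, 64], gl.float16, layout, num_buffers=2)
#   producer = chnl.create_producer()
#   smem, ready_bar, producer = producer.acquire()  # blocks until a slot is empty
#   ...fill smem, e.g. with a TMA load that arrives on ready_bar...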


@gluon.jit
def get_desc_channel(desc, num_buffers: gl.constexpr, num_consumers: gl.constexpr = 1):
    shape: gl.constexpr = desc.block_type.shape
    layout: gl.constexpr = desc.layout
    return SharedMemoryChannel.alloc(shape, desc.dtype, layout, num_buffers, num_consumers)


@gluon.jit
def issue_async_tma_load(smem, bar, desc, offset):
    mbarrier.expect(bar, desc.block_type.nbytes)
    tma.async_load(desc, [offset, 0], bar, smem)


# ===-----------------------------------------------------------------------===#
# Gluon Attention
# ===-----------------------------------------------------------------------===#


@gluon.aggregate
class AttentionConfig:
    qk_scale: gl.tensor
    Z: gl.tensor
    H: gl.tensor
    N_CTX: gl.tensor

    BLOCK_M: gl.constexpr
    BLOCK_N: gl.constexpr
    HEAD_DIM: gl.constexpr
    GROUP_SIZE_N: gl.constexpr
    NUM_SMS: gl.constexpr
    dtype: gl.constexpr
    num_warps: gl.constexpr

    SPLIT_D_FACTOR: gl.constexpr
    SPLIT_EXP_FACTOR: gl.constexpr
    SPLIT_QK_LOAD_FACTOR: gl.constexpr
    SPLIT_M: gl.constexpr
    SPLIT_D: gl.constexpr

    q_shape: gl.constexpr
    k_shape: gl.constexpr
    v_shape: gl.constexpr
    qk_shape: gl.constexpr
    o_shape: gl.constexpr

    qk_tmem_layout: gl.constexpr
    o_tmem_layout: gl.constexpr
    p_tmem_layout: gl.constexpr

    qk_layout: gl.constexpr
    o_splitn_layout: gl.constexpr
    alpha_2d_layout: gl.constexpr

    num_kv_buffers: gl.constexpr
    use_exp2_turnstile: gl.constexpr

    @gluon.constexpr_function
    def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE,
                 SPLIT_EXP_FACTOR, dtype, num_warps, NUM_KV_BUFFERS, USE_EXP2_TURNSTILE):
        self.qk_scale = qk_scale
        self.Z = Z
        self.H = H
        self.N_CTX = N_CTX

        self.BLOCK_M = gl.constexpr(BLOCK_M)
        self.BLOCK_N = gl.constexpr(BLOCK_N)
        self.HEAD_DIM = gl.constexpr(HEAD_DIM)
        self.GROUP_SIZE_N = gl.constexpr(GROUP_SIZE_N)
        self.NUM_SMS = gl.constexpr(NUM_SMS)
        self.dtype = gl.constexpr(dtype)
        self.num_warps = gl.constexpr(num_warps)

        self.SPLIT_D_FACTOR = gl.constexpr(2)
        self.SPLIT_EXP_FACTOR = gl.constexpr(SPLIT_EXP_FACTOR)
        self.SPLIT_QK_LOAD_FACTOR = gl.constexpr(2 if STAGE == 1 else 1)
        self.SPLIT_M = gl.constexpr(self.BLOCK_M // 2)
        self.SPLIT_D = gl.constexpr(self.HEAD_DIM // self.SPLIT_D_FACTOR)

        self.q_shape = gl.constexpr([self.SPLIT_M, self.HEAD_DIM])
        self.k_shape = gl.constexpr([self.BLOCK_N, self.HEAD_DIM])
        self.qk_shape = gl.constexpr([self.SPLIT_M, self.BLOCK_N])
        self.v_shape = gl.constexpr([self.BLOCK_N, self.HEAD_DIM])
        self.o_shape = gl.constexpr([self.SPLIT_M, self.HEAD_DIM])

        qk_instr_shape = get_mma_instr_shape(self.qk_shape, gl.float32)
        o_instr_shape = get_mma_instr_shape(self.o_shape, gl.float32)
        self.qk_tmem_layout = gl.constexpr(TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1))
        self.o_tmem_layout = gl.constexpr(TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), col_stride=1))
        self.p_tmem_layout = gl.constexpr(TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1))
        o_splitn_tmem_layout: gl.constexpr = TensorMemoryLayout(
            (o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR), col_stride=1)
        qk_tmem_ty: gl.constexpr = tensor_memory_descriptor_type(gl.float32, self.qk_shape, self.qk_tmem_layout,
                                                                 self.qk_shape)
        o_splitn_tmem_ty: gl.constexpr = tensor_memory_descriptor_type(
            gl.float32,
            [self.o_shape[0], self.o_shape[1] // self.SPLIT_D_FACTOR],
            o_splitn_tmem_layout,
            self.o_shape,
        )

        self.qk_layout = gl.constexpr(qk_tmem_ty.get_reg_layout(num_warps=self.num_warps,
                                                                instr_variant="32x32b_splitn"))
        self.o_splitn_layout = gl.constexpr(o_splitn_tmem_ty.get_reg_layout(num_warps=self.num_warps))
        self.alpha_2d_layout = gl.constexpr(gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1]))

        self.num_kv_buffers = gl.constexpr(NUM_KV_BUFFERS)
        self.use_exp2_turnstile = gl.constexpr(USE_EXP2_TURNSTILE)

    @gluon.jit
    def get_program(self, pid_m, pid_n):
        start_m = pid_m
        off_hz = pid_n
        off_z = off_hz // self.H
        off_h = off_hz % self.H

        offset_y = off_z * (self.N_CTX * self.H) + off_h * self.N_CTX
        qo_offset_y = offset_y + start_m * self.BLOCK_M

        return AttentionProgram(self, start_m, off_hz, offset_y, qo_offset_y)


@gluon.aggregate
class ProgramScheduler:
    config: AttentionConfig
    start_pid: gl.tensor
    num_pid_n: gl.tensor
    num_pid_in_group: gl.tensor
    num_tiles: gl.tensor

    @gluon.jit
    def create(config):
        start_pid = gl.program_id(0)
        num_pid_m = gl.cdiv(config.N_CTX, config.BLOCK_M)
        num_pid_n = config.Z * config.H
        num_pid_in_group = num_pid_m * config.GROUP_SIZE_N
        num_tiles = num_pid_m * num_pid_n
        return ProgramScheduler(config, start_pid, num_pid_n, num_pid_in_group, num_tiles)

    @gluon.jit
    def get_program(self, tile_id):
        group_id = tile_id // self.num_pid_in_group
        first_pid_n = group_id * self.config.GROUP_SIZE_N
        group_size_n = min(self.num_pid_n - first_pid_n, self.config.GROUP_SIZE_N)
        pid_n = first_pid_n + (tile_id % group_size_n)
        pid_m = (tile_id % self.num_pid_in_group) // group_size_n
        return self.config.get_program(pid_m, pid_n)
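
# Example: with num_pid_m = 4, num_pid_n = 2 and GROUP_SIZE_N = 2, tile_ids map
# to (pid_m, pid_n) = (0, 0), (0, 1), (1, 0), (1, 1), (2, 0), ... Sweeping
# pid_n fastest keeps the set of concurrently active heads small, so their K
# and V streams tend to stay resident in L2.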


@gluon.aggregate
class AttentionProgram:
    config: AttentionConfig
    start_m: gl.tensor
    off_hz: gl.tensor
    offset_y: gl.tensor
    qo_offset_y: gl.tensor

    @gluon.jit
    def get_fused_loop_bounds(self, STAGE: gl.constexpr):
        BLOCK_M: gl.constexpr = self.config.BLOCK_M
        if STAGE == 1:
            return 0, self.config.N_CTX
        elif STAGE == 2:
            return self.start_m * BLOCK_M, (self.start_m + 1) * BLOCK_M
        elif STAGE == 3:
            return 0, (self.start_m + 1) * BLOCK_M
        else:
            return 0, 0

    @gluon.jit
    def get_loop_bounds(self, STAGE: gl.constexpr):
        BLOCK_M: gl.constexpr = self.config.BLOCK_M
        if STAGE == 1:
            lo, hi = 0, self.start_m * BLOCK_M
        elif STAGE == 2:
            lo, hi = self.start_m * BLOCK_M, (self.start_m + 1) * BLOCK_M
        else:
            lo, hi = 0, self.config.N_CTX
        return lo, hi


# ===-----------------------------------------------------------------------===#
# _gluon_attn
# ===-----------------------------------------------------------------------===#


@gluon.jit
def _borrow_s_as_p(config, s_tmem):
    p_tmem = s_tmem.slice(0, config.BLOCK_N // 2)
    return p_tmem._reinterpret(config.dtype, config.qk_shape, config.p_tmem_layout)


@gluon.jit
def _borrow_s_as_alpha(config, s_tmem):
    alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1)
    alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
    return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout)


@gluon.jit
def _borrow_s_for_epilogue(config, s_tmem):
    m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
    l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
    layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
    m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
    l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
    return m_i_tmem, l_i_tmem
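
# Column map of the borrowed S tile ([SPLIT_M, BLOCK_N] fp32): the first
# BLOCK_N/2 fp32 columns are reinterpreted as the low-precision P operand,
# column BLOCK_N/2 holds alpha, and columns BLOCK_N/2 + 1 and + 2 hold m_i and
# l_i for the epilogue. Borrowing is safe because each slice is written only
# after the QK values it aliases have been read by the softmax partition.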


@gluon.constexpr_function
def _get_split_n_layout(layout: gl.constexpr, SPLIT_FACTOR: gl.constexpr = 2):
    assert isinstance(layout, gl.DistributedLinearLayout), "split_n requires a distributed layout"
    assert SPLIT_FACTOR == 1 or SPLIT_FACTOR == 2, "split_n requires a split factor of 1 or 2"
    if SPLIT_FACTOR == 1:
        return layout
    else:
        target = [0, layout.shape[1] // 2]  # [0, 2^{m-1}]
        last_reg_idx = len(layout.reg_bases) - 1
        reg_last = layout.reg_bases[last_reg_idx]

        if reg_last == target:
            return layout

        ret = copy.deepcopy(layout)

        # Find [0, 2^{m-1}] across lists and swap it with last reg
        for L in (ret.reg_bases, ret.lane_bases, ret.warp_bases, ret.block_bases):
            for i, b in enumerate(L):
                if b == target:
                    L[i], ret.reg_bases[last_reg_idx] = reg_last, target
                    return ret
        assert False, f"split_n requires having a basis {target}. Got\n{layout}"


@gluon.jit
def _split_n(x, SPLIT_FACTOR: gl.constexpr = 2):
    if SPLIT_FACTOR == 1:
        return (x, )
    else:
        layout: gl.constexpr = _get_split_n_layout(x.type.layout)
        x0, x1 = x.reshape([x.shape[0], 2, x.shape[1] // 2]).permute(0, 2, 1).split()
        x0 = gl.convert_layout(x0, layout, assert_trivial=True)
        x1 = gl.convert_layout(x1, layout, assert_trivial=True)
        return _split_n(x0, SPLIT_FACTOR // 2) + _split_n(x1, SPLIT_FACTOR // 2)
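
# _split_n recursively halves a tensor along dim 1 without moving data: the
# layout from _get_split_n_layout places the two halves in disjoint registers,
# so convert_layout(..., assert_trivial=True) compiles to nothing. For example,
# SPLIT_FACTOR=4 turns one [M, N] tensor into four [M, N/4] tensors.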


@gluon.constexpr_function
def _get_join_n_layout(layout, SPLIT_FACTOR: gl.constexpr = 2):
    assert isinstance(layout, gl.DistributedLinearLayout), "join_n requires a distributed linear layout"
    shape = list(layout.shape)
    regs = [[0, shape[1] * (1 << i)] for i in range(int(math.log2(SPLIT_FACTOR)))]
    shape[1] *= SPLIT_FACTOR
    return gl.DistributedLinearLayout(
        layout.reg_bases + regs,
        layout.lane_bases,
        layout.warp_bases,
        layout.block_bases,
        shape,
    )


@gluon.jit
def _join_n(xs):
    if len(xs) == 1:
        return xs[0]
    else:
        x0 = _join_n(xs[:len(xs) // 2])
        x1 = _join_n(xs[len(xs) // 2:])
        layout: gl.constexpr = _get_join_n_layout(x0.type.layout)
        x = gl.join(x0, x1).permute(0, 2, 1).reshape([x0.shape[0], x0.shape[1] * 2])
        return gl.convert_layout(x, layout, assert_trivial=True)


@gluon.jit
def _attn_fwd_load(config, chnls, descs, M, STAGE: gl.constexpr):
    q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
    desc_q, desc_k, desc_v, desc_o = descs

    q_producer = q_chnl.create_producer()
    kv_producer = kv_chnl.create_producer()

    scheduler = ProgramScheduler.create(config)
    for pid in range(scheduler.start_pid, scheduler.num_tiles, config.NUM_SMS):
        prog = scheduler.get_program(pid)
        lo, hi = prog.get_fused_loop_bounds(STAGE)

        q0_offset = prog.qo_offset_y + config.SPLIT_M * 0
        q0_smem, q0_bar, q_producer = q_producer.acquire()
        issue_async_tma_load(q0_smem, q0_bar, desc_q, q0_offset)

        offsetkv_y = prog.offset_y + lo
        k_smem, k_bar, kv_producer = kv_producer.acquire()
        issue_async_tma_load(k_smem, k_bar, desc_k, offsetkv_y)

        q1_offset = prog.qo_offset_y + config.SPLIT_M * 1
        q1_smem, q1_bar, q_producer = q_producer.acquire()
        issue_async_tma_load(q1_smem, q1_bar, desc_q, q1_offset)

        v_smem, v_bar, kv_producer = kv_producer.acquire()
        issue_async_tma_load(v_smem, v_bar, desc_v, offsetkv_y)

        for start_n in range(lo + config.BLOCK_N, hi, config.BLOCK_N):
            offsetkv_y = prog.offset_y + start_n
            k_smem, k_bar, kv_producer = kv_producer.acquire()
            issue_async_tma_load(k_smem, k_bar, desc_k, offsetkv_y)
            v_smem, v_bar, kv_producer = kv_producer.acquire()
            issue_async_tma_load(v_smem, v_bar, desc_v, offsetkv_y)


@gluon.jit
def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
    q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
    desc_q, desc_k, desc_v, desc_o = descs

    q_consumer = q_chnl.create_consumer()
    kv_consumer = kv_chnl.create_consumer()
    o_producer = o_chnl.create_producer()

    s0_producer = s0_chnl.create_producer()
    s1_producer = s1_chnl.create_producer()

    scheduler = ProgramScheduler.create(config)
    for pid in range(scheduler.start_pid, scheduler.num_tiles, config.NUM_SMS):
        prog = scheduler.get_program(pid)
        lo, hi = prog.get_fused_loop_bounds(STAGE)
        num_mmas = (hi - lo) // config.BLOCK_N

        q0_smem, q0_bar, q_consumer = q_consumer.acquire()
        k_smem, k_bar, kv_consumer = kv_consumer.acquire()
        s0_tmem, s0_bar, s0_producer = s0_producer.acquire()
        tcgen05_mma(q0_smem, k_smem.permute((1, 0)), s0_tmem, use_acc=False, mbarriers=[s0_bar])

        q1_smem, q1_bar, q_consumer = q_consumer.acquire()
        s1_tmem, s1_bar, s1_producer = s1_producer.acquire()
        tcgen05_mma(q1_smem, k_smem.permute((1, 0)), s1_tmem, use_acc=False, mbarriers=[s1_bar, k_bar])

        v_smem, v_bar, kv_consumer = kv_consumer.acquire()
        o0_tmem, o0_bar, o_producer = o_producer.acquire()
        s0_tmem, s0_bar, s0_producer = s0_producer.acquire()
        p0_tmem = _borrow_s_as_p(config, s0_tmem)
        tcgen05_mma(p0_tmem, v_smem, o0_tmem, use_acc=False, mbarriers=[o0_bar])
        o1_init = False

        for _ in range(num_mmas - 1):
            k_smem, k_bar, kv_consumer = kv_consumer.acquire()
            tcgen05_mma(q0_smem, k_smem.permute((1, 0)), s0_tmem, use_acc=False, mbarriers=[s0_bar])

            o1_tmem, o1_bar, o_producer = o_producer.acquire()
            s1_tmem, s1_bar, s1_producer = s1_producer.acquire()
            p1_tmem = _borrow_s_as_p(config, s1_tmem)
            tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o1_init, mbarriers=[o1_bar, v_bar])
            o1_init = True

            tcgen05_mma(q1_smem, k_smem.permute((1, 0)), s1_tmem, use_acc=False, mbarriers=[s1_bar, k_bar])

            v_smem, v_bar, kv_consumer = kv_consumer.acquire()
            o0_tmem, o0_bar, o_producer = o_producer.acquire()
            s0_tmem, s0_bar, s0_producer = s0_producer.acquire()
            p0_tmem = _borrow_s_as_p(config, s0_tmem)
            tcgen05_mma(p0_tmem, v_smem, o0_tmem, mbarriers=[o0_bar])

        tcgen05_commit(q0_bar)
        tcgen05_commit(q1_bar)

        o1_tmem, o1_bar, o_producer = o_producer.acquire()
        s1_tmem, s1_bar, s1_producer = s1_producer.acquire()
        p1_tmem = _borrow_s_as_p(config, s1_tmem)
        tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o1_init, mbarriers=[o1_bar, v_bar, s0_bar, s1_bar])


@gluon.jit
def _mask_scalar(qk, col_limit_right, s, i):
    col_lim_right_s = col_limit_right - s
    col_lim_right_cur = max(col_lim_right_s, 0)
    mask = -1 << col_lim_right_cur
    mask_i_bit = (mask & (1 << i)) == 0
    return gl.where(mask_i_bit, qk, -float("inf"))


@gluon.jit
def _apply_causal_mask(qk, col_limit_right):
    # Apply causal mask via a bitmask calculated for each block of 16 elements.
    # This allows the efficient R2P (register to predicate) instruction to be used at the SASS level.
    # Credit to Tri Dao,
    # https://github.com/Dao-AILab/flash-attention/commit/bac1001e4f6caa09d70537495d6746a685a2fa78
    #
    # NOTE: We use map_elementwise here in order to generate an interleaved sequence of instructions
    # that processes one element of qk at a time. This improves ptxas's resulting SASS.
    offs_n = gl.arange(0, qk.shape[1])[None, :]
    s = offs_n & ~0xf
    i = offs_n & 0xf
    return gl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i)
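
# Worked example: if a row may see col_limit_right = 20 columns and we are in
# the 16-wide block starting at s = 16, then col_lim_right_cur =
# max(20 - 16, 0) = 4 and mask = -1 << 4 = ...11110000, so lanes i = 0..3
# (columns 16..19) keep their scores while i >= 4 become -inf.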


@gluon.jit
def _compute_and_store_exp2(config, qk, p_tmem):
    SIZE: gl.constexpr = p_tmem.shape[1] // config.SPLIT_EXP_FACTOR
    qks = _split_n(qk, config.SPLIT_EXP_FACTOR)
    ps = ()
    for i in gl.static_range(config.SPLIT_EXP_FACTOR):
        p = gl.exp2(qks[i])
        p_tmem.slice(i * SIZE, SIZE).store(p.to(config.dtype))
        ps = ps + (p, )
    return _join_n(ps)


@gluon.jit
def _subtiled_qk_load(config, s_tmem, use_tmem_red: gl.constexpr):
    SIZE: gl.constexpr = s_tmem.shape[1] // config.SPLIT_QK_LOAD_FACTOR
    qks = ()
    if use_tmem_red:
        red_total = None
        for i in gl.static_range(config.SPLIT_QK_LOAD_FACTOR):
            vals, reds = s_tmem.slice(i * SIZE, SIZE).load_max()
            red_total = reds if red_total is None else gl.maximum(red_total, reds)
            qks = qks + (vals, )
        return _join_n(qks), red_total
    else:
        for i in gl.static_range(config.SPLIT_QK_LOAD_FACTOR):
            qks = qks + (s_tmem.slice(i * SIZE, SIZE).load(), )
        return _join_n(qks), None


@gluon.jit
def _softmax_inner_loop(tile_id: gl.constexpr, config, prog,  #
                        s_consumer, corr_producer, exp_turnstile, corr_bar,  #
                        offs_m, m_i, l_i, STAGE: gl.constexpr, use_tmem_red: gl.constexpr):
    lo, hi = prog.get_loop_bounds(STAGE)

    for start_n in range(lo, hi, config.BLOCK_N):
        s_tmem, s_bar, s_consumer = s_consumer.acquire()
        qk, qk_max = _subtiled_qk_load(config, s_tmem, use_tmem_red)

        if STAGE == 2:
            col_limit_right = (offs_m - start_n + 1)[:, None]
            qk = _apply_causal_mask(qk, col_limit_right)

        if use_tmem_red:
            qk_max = gl.convert_layout(qk_max, m_i.type.layout)
            m_ij = gl.maximum(m_i, qk_max * config.qk_scale)
        else:
            m_ij = gl.maximum(m_i, gl.max(qk, 1) * config.qk_scale)
        alpha = gl.exp2(m_i - m_ij)

        alpha_tmem = _borrow_s_as_alpha(config, s_tmem)
        alpha_tmem.store(gl.convert_layout(alpha.expand_dims(1), config.alpha_2d_layout))
        mbarrier.arrive(corr_bar, count=1)

        rowmax = float2.pack(-m_ij[:, None].broadcast_to(qk.shape), axis=1)
        qk = float2.pack(qk, axis=1)
        qk = float2.fma(qk, float2.full_like(qk, config.qk_scale), rowmax)
        qk = float2.unpack(qk, axis=1)

        # Force the softmax partitions to take turns in the EX2 section. This
        # prevents contention for the EX2 unit and improves utilization.
        if config.use_exp2_turnstile:
            _, exp_bar, exp_turnstile = exp_turnstile.acquire()

    # FIXME: When using FADD2 reductions, ptxas misbehaves and spills even
    # though register usage is far below the limit in the FADD2, FMUL2, EX2
    # section. Subtile by 4 to minimize the spilling.
        p_tmem = _borrow_s_as_p(config, s_tmem)
        p = _compute_and_store_exp2(config, qk, p_tmem)

        mbarrier.arrive(s_bar, count=1)
        _, corr_bar, corr_producer = corr_producer.acquire()

        if config.use_exp2_turnstile:
            mbarrier.arrive(exp_bar, count=1)

        l_ij = float2.pack2(*_split_n(p)).sum(axis=1)
        l_ij = Float2Tensor(gl.convert_layout(l_ij.value, l_i.value.type.layout, assert_trivial=True))
        alpha = gl.convert_layout(alpha, l_i.value.type.layout, assert_trivial=True)
        l_i = float2.fma(l_i, float2.pack2(alpha, alpha), l_ij)
        m_i = m_ij

    return m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile
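
# The loop above is the standard online-softmax recurrence, in base-2 form:
#   m_ij  = max(m_i, row_max(qk) * qk_scale)
#   alpha = exp2(m_i - m_ij)
#   p     = exp2(qk * qk_scale - m_ij)
#   l_i   = l_i * alpha + row_sum(p)
# The correction partition rescales the partial output by alpha as it goes, so
# dividing by the final l_i in the epilogue yields softmax(Q K^T * sm_scale) V.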


@gluon.jit
def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,  #
                  s_chnl, corr_chnl, exp_turnstile, use_tmem_red: gl.constexpr):
    qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout)
    sum_layout: gl.constexpr = _get_split_n_layout(config.qk_layout)

    s_consumer = s_chnl.create_consumer()
    corr_producer = corr_chnl.create_producer()
    _, corr_bar, corr_producer = corr_producer.acquire()

    scheduler = ProgramScheduler.create(config)
    for pid in range(scheduler.start_pid, scheduler.num_tiles, config.NUM_SMS):
        prog = scheduler.get_program(pid)

        offs_m = prog.start_m * config.BLOCK_M
        offs_m += gl.arange(tile_id * config.SPLIT_M, (1 + tile_id) * config.SPLIT_M)

        m_i = gl.full([config.SPLIT_M], -float("inf"), gl.float32, qk_slice_dim1)
        # Accumulate into 2 row-sums so the reduction can be performed with FADD2.
        l_i = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
        l_i = float2.pack2(l_i, l_i)

        if STAGE & 1:
            m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = _softmax_inner_loop(  #
                tile_id, config, prog, s_consumer, corr_producer, exp_turnstile, corr_bar,  #
                offs_m, m_i, l_i, STAGE=4 - STAGE, use_tmem_red=use_tmem_red)
        if STAGE & 2:
            m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = _softmax_inner_loop(  #
                tile_id, config, prog, s_consumer, corr_producer, exp_turnstile, corr_bar,  #
                offs_m, m_i, l_i, STAGE=2, use_tmem_red=use_tmem_red)
        l_i0, l_i1 = float2.unpack2(l_i)
        l_i = l_i0 + l_i1

        s_tmem, s_bar, s_consumer = s_consumer.acquire()
        m_i_tmem, l_i_tmem = _borrow_s_for_epilogue(config, s_tmem)
        m_i_tmem.store(gl.convert_layout(m_i.expand_dims(1), config.alpha_2d_layout))
        l_i_tmem.store(gl.convert_layout(l_i.expand_dims(1), config.alpha_2d_layout))

        mbarrier.arrive(corr_bar, count=1)
        _, corr_bar, corr_producer = corr_producer.acquire()

        mbarrier.arrive(s_bar, count=1)


@gluon.jit
def _attn_fwd_softmax0(config, chnls, descs, M, STAGE: gl.constexpr, use_tmem_red: gl.constexpr):
    q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
    desc_q, desc_k, desc_v, desc_o = descs
    _softmax_tile(0, config, M, desc_o, STAGE, s0_chnl, c0_chnl, exp_turnstile.create_producer(), use_tmem_red)


@gluon.jit
def _attn_fwd_softmax1(config, chnls, descs, M, STAGE: gl.constexpr, use_tmem_red: gl.constexpr):
    q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
    desc_q, desc_k, desc_v, desc_o = descs
    _softmax_tile(1, config, M, desc_o, STAGE, s1_chnl, c1_chnl, exp_turnstile.create_consumer(), use_tmem_red)


@gluon.jit
def _attn_fwd_epilogue(config, chnls, descs, M, STAGE: gl.constexpr):
    q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
    desc_q, desc_k, desc_v, desc_o = descs

    epi_consumer = epi_chnl.create_consumer()
    scheduler = ProgramScheduler.create(config)
    for pid in range(scheduler.start_pid, scheduler.num_tiles, config.NUM_SMS):
        prog = scheduler.get_program(pid)

        o0_smem, o0_bar, epi_consumer = epi_consumer.acquire()
        tma.async_copy_shared_to_global(desc_o, [prog.qo_offset_y + config.SPLIT_M * 0, 0], o0_smem)

        o1_smem, o1_bar, epi_consumer = epi_consumer.acquire()
        tma.async_copy_shared_to_global(desc_o, [prog.qo_offset_y + config.SPLIT_M * 1, 0], o1_smem)

        # Allow at most one store to remain in flight: store_wait(1) drains the
        # o0 store before releasing its buffer; store_wait(0) drains the rest.
        tma.store_wait(1)
        mbarrier.arrive(o0_bar, count=1)
        tma.store_wait(0)
        mbarrier.arrive(o1_bar, count=1)


@gluon.jit
def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
    alpha_layout: gl.constexpr = gl.SliceLayout(1, config.o_splitn_layout)

    o_tmem, o_bar, o_consumer = o_consumer.acquire()

    _, corr_bar, corr_consumer = corr_consumer.acquire()
    alpha = _borrow_s_as_alpha(config, s_tmem).load(config.alpha_2d_layout)
    mbarrier.arrive(corr_bar, count=1)
    alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout)

    alpha = float2.pack(alpha[:, None].broadcast_to(config.o_shape[0], config.SPLIT_D), axis=1)
    for i in gl.static_range(config.SPLIT_D_FACTOR):
        o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
        o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
        o = o * alpha
        o_ref.store(float2.unpack(o, axis=1))
    mbarrier.arrive(o_bar, count=1)
    return corr_consumer, o_consumer


@gluon.jit
def _attn_fwd_correction_epilogue(config, prog, s_tmem, M, corr_consumer, epi_producer, o_consumer):
    alpha_layout: gl.constexpr = gl.SliceLayout(1, config.o_splitn_layout)

    _, corr_bar, corr_consumer = corr_consumer.acquire()
    m_i_tmem, l_i_tmem = _borrow_s_for_epilogue(config, s_tmem)
    m_i = m_i_tmem.load(config.alpha_2d_layout).reshape([config.SPLIT_M])
    m_i = gl.convert_layout(m_i, alpha_layout)
    l_i = l_i_tmem.load(config.alpha_2d_layout).reshape([config.SPLIT_M])
    l_i = gl.convert_layout(l_i, alpha_layout)
    mbarrier.arrive(corr_bar, count=1)

    o_smem, epi_bar, epi_producer = epi_producer.acquire()
    o_tmem, o_bar, o_consumer = o_consumer.acquire()

    # Shared memory subtile size is limited by the swizzle byte size.
    contigDimSize: gl.constexpr = o_smem.type.layout.swizzle_byte_width * 8 // o_smem.type.element_ty.primitive_bitwidth
    if o_smem.type.shape[1] // config.SPLIT_D_FACTOR >= contigDimSize:
        SPLIT_N_FACTOR: gl.constexpr = config.SPLIT_D_FACTOR
    else:
        SPLIT_N_FACTOR: gl.constexpr = 1
    gl.static_assert(o_smem.type.shape[1] // SPLIT_N_FACTOR >= contigDimSize,
                     "Block shape is too small for the swizzle byte size in NVMMA Shared Layout")
    SPLIT_N: gl.constexpr = o_smem.type.shape[1] // SPLIT_N_FACTOR

    scale = float2.pack((1 / l_i)[:, None].broadcast_to(config.o_shape[0], SPLIT_N), axis=1)
    for i in gl.static_range(SPLIT_N_FACTOR):
        o_ref = o_tmem.slice(i * SPLIT_N, SPLIT_N)
        o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
        o = o * scale
        o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(float2.unpack(o, axis=1).to(config.dtype))

    fence_async_shared()
    mbarrier.arrive(epi_bar, count=1)
    mbarrier.arrive(o_bar, count=1)

    m_i += gl.log2(l_i)
    coalesced: gl.constexpr = gl.BlockedLayout([1], [32], [config.num_warps], [0])
    offs_m = prog.start_m * config.BLOCK_M
    offs_m += gl.arange(0 * config.SPLIT_M, 1 * config.SPLIT_M, coalesced)
    m_ptrs = M + prog.off_hz * config.N_CTX + offs_m
    gl.store(m_ptrs, gl.convert_layout(m_i, coalesced))

    return corr_consumer, epi_producer, o_consumer


@gluon.jit
def _attn_fwd_correction(config, chnls, descs, M, STAGE: gl.constexpr):
    q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls

    s0_tmem = s0_chnl.mem.index(0)
    s1_tmem = s1_chnl.mem.index(0)
    corr0_consumer = c0_chnl.create_consumer()
    corr1_consumer = c1_chnl.create_consumer()
    o_consumer = o_chnl.create_consumer()

    epi_producer = epi_chnl.create_producer()

    scheduler = ProgramScheduler.create(config)
    for pid in range(scheduler.start_pid, scheduler.num_tiles, config.NUM_SMS):
        prog = scheduler.get_program(pid)
        lo, hi = prog.get_fused_loop_bounds(STAGE)
        num_corrections = (hi - lo) // config.BLOCK_N

        _, corr0_bar, corr0_consumer = corr0_consumer.acquire()
        mbarrier.arrive(corr0_bar, count=1)
        _, corr1_bar, corr1_consumer = corr1_consumer.acquire()
        mbarrier.arrive(corr1_bar, count=1)

        for i in range(num_corrections - 1):
            corr0_consumer, o_consumer = _attn_fwd_correction_rescale(config, s0_tmem, corr0_consumer, o_consumer)
            corr1_consumer, o_consumer = _attn_fwd_correction_rescale(config, s1_tmem, corr1_consumer, o_consumer)

        corr0_consumer, epi_producer, o_consumer = _attn_fwd_correction_epilogue(  #
            config, prog, s0_tmem, M, corr0_consumer, epi_producer, o_consumer)
        corr1_consumer, epi_producer, o_consumer = _attn_fwd_correction_epilogue(  #
            config, prog, s1_tmem, M, corr1_consumer, epi_producer, o_consumer)


def attention_repr(specialization):
    name = "gluon_attention"
    # Up to 150 TFLOPS faster for fp8!
    if specialization.constants["dtype"] == gl.float8e5:
        name = "cutlass_" + name
    return name


@gluon.jit(do_not_specialize=["Z", "H", "N_CTX"], repr=attention_repr)
def attention_kernel(  #
        sm_scale, M, Z, H, N_CTX, desc_q, desc_k, desc_v, desc_o,  #
        BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr, HEAD_DIM: gl.constexpr,  #
        GROUP_SIZE_N: gl.constexpr, NUM_SMS: gl.constexpr, STAGE: gl.constexpr, SPLIT_EXP_FACTOR: gl.constexpr,  #
        dtype: gl.constexpr, num_warps: gl.constexpr, use_tmem_red: gl.constexpr, NUM_KV_BUFFERS: gl.constexpr,
        USE_EXP2_TURNSTILE: gl.constexpr):
    qk_scale = sm_scale * 1.44269504
    config = AttentionConfig(qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE,
                             SPLIT_EXP_FACTOR,  #
                             dtype, num_warps, NUM_KV_BUFFERS, USE_EXP2_TURNSTILE)

    q_chnl = get_desc_channel(desc_q, num_buffers=2)
    kv_chnl = get_desc_channel(desc_k, num_buffers=config.num_kv_buffers)
    o_chnl = TensorMemoryChannel.alloc(config.o_shape, gl.float32, config.o_tmem_layout, num_buffers=2)
    epi_chnl = SharedMemoryChannel.alloc(config.o_shape, config.dtype, gl.constexpr(desc_o.layout), num_buffers=2)
    s0_chnl = TensorMemoryChannel.alloc(config.qk_shape, gl.float32, config.qk_tmem_layout, num_buffers=1)
    s1_chnl = TensorMemoryChannel.alloc(config.qk_shape, gl.float32, config.qk_tmem_layout, num_buffers=1)
    c0_chnl = SharedMemoryChannel.alloc([1], gl.int8, gl.constexpr(mbarrier.MBarrierLayout()), num_buffers=1)
    c1_chnl = SharedMemoryChannel.alloc([1], gl.int8, gl.constexpr(mbarrier.MBarrierLayout()), num_buffers=1)
    exp_turnstile = SharedMemoryChannel.alloc([1], gl.int8, gl.constexpr(mbarrier.MBarrierLayout()), num_buffers=1)

    chnls = (q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile)
    descs = (desc_q, desc_k, desc_v, desc_o)
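    # Launch the warp-specialized partitions. The correction partition leads as
    # the default partition; the remaining five line up with the
    # [4, 4, 1, 1, 1] warp counts and [192, 192, 24, 24, 24] register budgets.
    # Dataflow: load -> (q/kv channels) -> mma -> (s0/s1) -> softmax0/1
    # -> (c0/c1) -> correction -> (o/epi channels) -> epilogue.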
    gl.warp_specialize([
        (_attn_fwd_correction, (config, chnls, descs, M, STAGE)),
        (_attn_fwd_softmax0, (config, chnls, descs, M, STAGE, use_tmem_red)),
        (_attn_fwd_softmax1, (config, chnls, descs, M, STAGE, use_tmem_red)),
        (_attn_fwd_mma, (config, chnls, descs, M, STAGE)),
        (_attn_fwd_load, (config, chnls, descs, M, STAGE)),
        (_attn_fwd_epilogue, (config, chnls, descs, M, STAGE)),
    ], [4, 4, 1, 1, 1], [192, 192, 24, 24, 24])

    q_chnl.release()
    kv_chnl.release()
    o_chnl.release()
    epi_chnl.release()
    s0_chnl.release()
    s1_chnl.release()
    c0_chnl.release()
    c1_chnl.release()
    exp_turnstile.release()


# ===-----------------------------------------------------------------------===#
# Entry Point
# ===-----------------------------------------------------------------------===#


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_blackwell():
    return is_cuda() and torch.cuda.get_device_capability()[0] == 10


def is_blackwell_ultra():
    return is_cuda() and torch.cuda.get_device_capability()[0:2] == (10, 3)


@dataclass(frozen=True, slots=True)
class KernelConfig:
    BLOCK_M: int = 256
    BLOCK_N: int = 128
    GROUP_SIZE_N: int | None = None
    SPLIT_EXP_FACTOR: int | None = None
    NUM_WARPS: int = 4
    MAXNREG: int = 128
    OCCUPANCY: int = 1
    USE_TMEM_RED: bool = False
    NUM_KV_BUFFERS: int | None = None
    USE_EXP2_TURNSTILE: bool | None = None


def _default_split_exp_factor(head_dim: int) -> int:
    return max(1, 256 // head_dim)


def _default_num_kv_buffers(head_dim: int, dtype: torch.dtype) -> int:
    is_fp16 = dtype in [torch.float16, torch.bfloat16]
    if is_fp16:
        return 3 if head_dim == 128 else 6
    return 4 if head_dim == 128 else 8


def select_kernel_config(
    head_dim: int,
    n_ctx: int,
    dtype: torch.dtype,
    causal: bool,
    use_tmem_red: bool,
    override: KernelConfig | None = None,
) -> KernelConfig:
    is_fp8 = dtype == torch.float8_e5m2
    is_bf16 = dtype == torch.bfloat16
    is_bwu = is_blackwell_ultra()

    block_m = 256
    block_n = 128
    group_size_n = 1
    split_exp_factor = _default_split_exp_factor(head_dim)
    num_warps = 4
    maxnreg = 128
    occupancy = 1
    use_selected_tmem_red = (use_tmem_red or is_bwu) and not causal
    num_kv_buffers = _default_num_kv_buffers(head_dim, dtype)
    use_exp2_turnstile = head_dim == 64

    if causal:
        group_size_n = 8 if head_dim == 64 or n_ctx <= 2048 else 4

    if head_dim == 128:
        split_exp_factor = 4
        if not causal and is_bf16 and n_ctx <= 2048:
            group_size_n = 4
    elif not causal and head_dim == 64 and use_selected_tmem_red:
        split_exp_factor = 1
        if n_ctx <= 1024:
            num_kv_buffers = 2
        elif n_ctx >= 8192:
            maxnreg = 112
    elif causal and head_dim == 64:
        num_kv_buffers = 2
        if n_ctx <= 1024:
            split_exp_factor = 2
        else:
            use_exp2_turnstile = False

    if is_fp8:
        if causal and head_dim == 64:
            group_size_n = 8 if n_ctx <= 2048 else 4
            split_exp_factor = 4 if n_ctx <= 2048 else 2
            maxnreg = 112 if n_ctx >= 4096 else 128
            use_selected_tmem_red = False
            num_kv_buffers = 2
            use_exp2_turnstile = n_ctx <= 1024
        elif causal and head_dim == 128:
            group_size_n = 8 if n_ctx <= 2048 else 4
            split_exp_factor = 2 if n_ctx <= 2048 else 8
            maxnreg = 128
            use_selected_tmem_red = False
            num_kv_buffers = 4
            use_exp2_turnstile = False
        elif not causal and head_dim == 64:
            group_size_n = 1
            split_exp_factor = 2
            maxnreg = 128
            use_selected_tmem_red = is_bwu
            num_kv_buffers = 2 if n_ctx <= 1024 else 8
            use_exp2_turnstile = True
        elif not causal and head_dim == 128:
            group_size_n = 1
            split_exp_factor = 4 if n_ctx <= 2048 else 8
            maxnreg = 128
            use_selected_tmem_red = is_bwu
            num_kv_buffers = 4
            use_exp2_turnstile = False
        else:
            group_size_n = 4 if causal else 1
            split_exp_factor = _default_split_exp_factor(head_dim)
            use_selected_tmem_red = use_tmem_red and not causal

    config = KernelConfig(
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        GROUP_SIZE_N=group_size_n,
        SPLIT_EXP_FACTOR=split_exp_factor,
        NUM_WARPS=num_warps,
        MAXNREG=maxnreg,
        OCCUPANCY=occupancy,
        USE_TMEM_RED=use_selected_tmem_red,
        NUM_KV_BUFFERS=num_kv_buffers,
        USE_EXP2_TURNSTILE=use_exp2_turnstile,
    )
    if override is None:
        return config

    values = {field.name: getattr(override, field.name) for field in fields(KernelConfig)}
    values = {name: getattr(config, name) if value is None else value for name, value in values.items()}
    return KernelConfig(**values)
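
# Override semantics: fields left as None in the override inherit the value
# selected above, while every other field (including dataclass defaults such as
# BLOCK_M=256) takes precedence. For example,
# select_kernel_config(64, 4096, torch.float16, causal=False,
# use_tmem_red=False, override=KernelConfig(NUM_KV_BUFFERS=4)) keeps the
# auto-selected GROUP_SIZE_N but pins NUM_KV_BUFFERS to 4.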


def torch_dtype_to_triton(dtype):
    if dtype == torch.float8_e5m2:
        return gl.float8e5
    return getattr(gl, str(dtype).split('.')[1])


def make_tensor_desc(x, shape, strides, block_shape):
    layout = gl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(x.dtype))
    return TensorDescriptor(x, shape=shape, strides=strides, block_shape=block_shape, layout=layout)


def attention_forward(q, k, v, causal, sm_scale, o=None, M=None, *, use_tmem_red=False, p: KernelConfig | None = None):
    # Backward-compatibility shim: tolerate use_tmem_red being passed
    # positionally in place of o, i.e. attention_forward(q, k, v, causal,
    # sm_scale, use_tmem_red).
    if isinstance(o, bool) and M is None and use_tmem_red is False:
        use_tmem_red = o
        o = None

    HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
    HEAD_DIM_V = v.shape[-1]
    assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
    assert HEAD_DIM_K in {16, 32, 64, 128, 256}

    stage = 3 if causal else 1
    p = select_kernel_config(HEAD_DIM_K, q.shape[2], q.dtype, causal, use_tmem_red, override=p)

    if o is None:
        o = torch.empty_like(q)
    if M is None:
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

    y_dim = q.shape[0] * q.shape[1] * q.shape[2]

    # The kernel will split BLOCK_M into two subtiles.
    BLOCK_M = p.BLOCK_M
    BLOCK_N = p.BLOCK_N
    SPLIT_M = BLOCK_M // 2
    GROUP_SIZE_N = p.GROUP_SIZE_N
    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count * p.OCCUPANCY

    desc_q = make_tensor_desc(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[SPLIT_M, HEAD_DIM_K])
    desc_v = make_tensor_desc(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
    desc_k = make_tensor_desc(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
    desc_o = make_tensor_desc(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[SPLIT_M, HEAD_DIM_K])

    num_pid_m = triton.cdiv(q.shape[2], BLOCK_M)
    num_pid_n = q.shape[0] * q.shape[1]
    grid = min(NUM_SMS, num_pid_m * num_pid_n)

    attention_kernel[(grid, )](
        sm_scale, M, q.shape[0], q.shape[1], q.shape[2],  #
        desc_q, desc_k, desc_v, desc_o,  #
        BLOCK_M, BLOCK_N, HEAD_DIM_K, GROUP_SIZE_N, NUM_SMS,  #
        SPLIT_EXP_FACTOR=p.SPLIT_EXP_FACTOR, STAGE=stage, dtype=torch_dtype_to_triton(q.dtype),  #
        num_warps=p.NUM_WARPS, maxnreg=p.MAXNREG, use_tmem_red=p.USE_TMEM_RED, NUM_KV_BUFFERS=p.NUM_KV_BUFFERS,
        USE_EXP2_TURNSTILE=p.USE_EXP2_TURNSTILE)

    return o, M
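
# Minimal usage sketch (assuming a Blackwell GPU and an allocator registered as
# in test_op below):
#
#   q = torch.randn(4, 32, 4096, 64, dtype=torch.float16, device="cuda")
#   k, v = torch.randn_like(q), torch.randn_like(q)
#   o, M = attention_forward(q, k, v, causal=True, sm_scale=0.5)
#
# o is the attention output; M holds the per-row base-2 log-sum-exp statistics.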


# ===-----------------------------------------------------------------------===#
# Unit Tests
# ===-----------------------------------------------------------------------===#


@pytest.mark.parametrize("Z", [1, 4])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("N_CTX", [1024, 2048, 4096, 8192])
@pytest.mark.parametrize("HEAD_DIM", [64, 128])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_tmem_red", [False, True] if is_blackwell_ultra() else [False])
@pytest.mark.skipif(not is_blackwell(), reason="Gluon attention is only supported on Blackwell GPUs")
def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype, use_tmem_red, profile=False):
    device = "cuda"

    def alloc_fn(size: int, alignment: int, stream):
        return torch.empty(size, dtype=torch.int8, device=device)

    triton.set_allocator(alloc_fn)

    if use_tmem_red and not is_blackwell_ultra():
        pytest.skip("TMEM reduction is only supported on Blackwell Ultra GPUs")

    torch.manual_seed(42)
    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
    sm_scale = 0.5

    ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal)

    tri_out, _ = attention_forward(q, k, v, causal, sm_scale, use_tmem_red=use_tmem_red)
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)


# ===-----------------------------------------------------------------------===#
# Benchmarking
# ===-----------------------------------------------------------------------===#

BATCH = [4]
N_HEADS = [32]
HEAD_DIM = [64, 128]
causal = [False, True]
providers = ["triton-fp16", "triton-fp8"]
N_CTX = [2**i for i in range(10, 17)]
use_tmem_reds = [False, True] if is_blackwell_ultra() else [False]

bench_configs = []
for Z, H, D, is_causal, use_tmem_red in itertools.product(BATCH, N_HEADS, HEAD_DIM, causal, use_tmem_reds):
    config = triton.testing.Benchmark(
        x_names=["N_CTX"],
        x_vals=N_CTX,
        line_arg="provider",
        line_vals=providers,
        line_names=providers,
        styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
        ylabel="TFLOPS",
        plot_name=f"Attention Z={Z} H={H} D={D} causal={is_causal} use_tmem_red={use_tmem_red}",
        args={
            "Z": Z,
            "H": H,
            "HEAD_DIM": D,
            "causal": is_causal,
            "use_tmem_red": use_tmem_red,
        },
    )
    bench_configs.append(config)


@triton.testing.perf_report(bench_configs)
def bench(Z, H, N_CTX, HEAD_DIM, causal, use_tmem_red, provider):
    provider, dtype = provider.split("-")
    if dtype == "fp16":
        dtype = torch.float16
    elif dtype == "bf16":
        dtype = torch.bfloat16
    elif dtype == "fp8":
        dtype = torch.float8_e5m2
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")
    device = "cuda"

    torch.manual_seed(42)
    q = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=device).normal_(mean=0.0, std=0.5).requires_grad_()).to(dtype)
    k = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=device).normal_(mean=0.0, std=0.5).requires_grad_()).to(dtype)
    v = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=device).normal_(mean=0.0, std=0.5).requires_grad_()).to(dtype)
    sm_scale = 1.3

    o = torch.empty_like(q)
    M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.CUDNN_ATTENTION]):
        if provider == "triton":
            fn = lambda: attention_forward(q, k, v, causal, sm_scale, o, M, use_tmem_red=use_tmem_red)
        elif provider == "cudnn":
            fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal)
        else:
            raise ValueError(f"Unsupported provider: {provider}")

        ms = triton.testing.do_bench_cudagraph(fn)
        flops_per_matmul = 2.0 * Z * H * N_CTX * N_CTX * HEAD_DIM
        total_flops = 2 * flops_per_matmul
        if causal:
            total_flops *= 0.5
        return total_flops * 1e-12 / (ms * 1e-3)


if __name__ == "__main__":
    bench.run(save_path=".", print_data=True)