Attention Forward
This example implements a persistent, warp-specialized attention forward kernel (FlashAttention-style online softmax) for NVIDIA Blackwell GPUs in Gluon, built on TMA loads, tcgen05 MMAs, and tensor memory. The full source can be found at python/examples/gluon/01-attention-forward.py.
import copy
import math
import torch
import triton
import pytest
import itertools
from dataclasses import dataclass, fields
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
from triton.experimental.gluon.language.nvidia.hopper import fence_async_shared
from triton.experimental.gluon.language.nvidia.blackwell import (
TensorMemoryLayout,
allocate_tensor_memory,
tensor_memory_descriptor,
tensor_memory_descriptor_type,
tma,
mbarrier,
tcgen05_mma,
tcgen05_commit,
float2,
)
from triton.experimental.gluon.language.nvidia.blackwell.float2 import Float2Tensor
# ===-----------------------------------------------------------------------===#
# Layout Utilities
# ===-----------------------------------------------------------------------===#
@gluon.constexpr_function
def get_mma_instr_shape(shape, element_ty):
m = 128 if shape[0] >= 128 else 64
n = 256 if shape[1] >= 256 else shape[1]
k = 256 // element_ty.primitive_bitwidth
return (m, n, k)
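# For example, a [128, 128] tile of fp16 operands maps to a 128x128 MMA instruction with
# k = 256 / 16 = 16, while a [64, 256] fp32 tile maps to (m, n, k) = (64, 256, 8).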
# ===-----------------------------------------------------------------------===#
# Data Abstractions
# ===-----------------------------------------------------------------------===#
@gluon.aggregate
class BarrierCounter:
index: gl.tensor
phase: gl.tensor
num_barriers: gl.constexpr
@gluon.must_use_result
@gluon.jit
def increment(self):
if self.num_barriers == 1:
return BarrierCounter(gl.to_tensor(0), self.phase ^ 1, self.num_barriers)
next_index = self.index + 1
rollover = next_index == self.num_barriers
index = gl.where(rollover, 0, next_index)
phase = gl.where(rollover, self.phase ^ 1, self.phase)
return BarrierCounter(index, phase, self.num_barriers)
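# Example: with num_barriers=2 the counter cycles through
#   (index=0, phase=0) -> (1, 0) -> (0, 1) -> (1, 1) -> (0, 0) -> ...
# i.e. the phase bit flips each time the index wraps, which is the parity that
# mbarrier.wait is given below when spinning on a multi-buffered channel.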
def Channel(T, alloc_fn):
@gluon.aggregate
class ChannelType:
mem: T
ready_bars: gl.shared_memory_descriptor
empty_bars: gl.shared_memory_descriptor
num_buffers: gl.constexpr
num_consumers: gl.constexpr
@gluon.jit
def alloc(shape: gl.constexpr, dtype: gl.constexpr, layout: gl.constexpr, num_buffers: gl.constexpr,
num_consumers: gl.constexpr = 1):
mem = alloc_fn(dtype, [num_buffers] + shape, layout)
ready_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
empty_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
for i in gl.static_range(num_buffers):
mbarrier.init(ready_bars.index(i), count=1)
mbarrier.init(empty_bars.index(i), count=num_consumers)
mbarrier.arrive(empty_bars.index(i), count=num_consumers)
return ChannelType(mem, ready_bars, empty_bars, num_buffers, num_consumers)
@gluon.jit
def acquire_producer(self, counter):
index, phase = counter.index, counter.phase
mem = self.mem.index(index)
ready_bar = self.ready_bars.index(index)
empty_bar = self.empty_bars.index(index)
mbarrier.wait(empty_bar, phase)
return mem, ready_bar
@gluon.jit
def acquire_consumer(self, counter):
index, phase = counter.index, counter.phase
mem = self.mem.index(index)
ready_bar = self.ready_bars.index(index)
empty_bar = self.empty_bars.index(index)
mbarrier.wait(ready_bar, phase)
return mem, empty_bar
@gluon.jit
def create_counter(self):
return BarrierCounter(gl.to_tensor(0), gl.to_tensor(0), self.num_buffers)
@gluon.jit
def create_producer(self):
return Producer(self, self.create_counter())
@gluon.jit
def create_consumer(self):
return Consumer(self, self.create_counter())
@gluon.jit
def release(self):
if isinstance(self.mem, gl.shared_memory_descriptor):
self.mem._keep_alive()
for i in gl.static_range(self.num_buffers):
mbarrier.invalidate(self.ready_bars.index(i))
mbarrier.invalidate(self.empty_bars.index(i))
@gluon.aggregate
class Producer:
channel: ChannelType
counter: BarrierCounter
@gluon.jit
def acquire(self):
mem, ready_bar = self.channel.acquire_producer(self.counter)
next = Producer(self.channel, self.counter.increment())
return mem, ready_bar, next
@gluon.aggregate
class Consumer:
channel: ChannelType
counter: BarrierCounter
@gluon.jit
def acquire(self):
mem, empty_bar = self.channel.acquire_consumer(self.counter)
next = Consumer(self.channel, self.counter.increment())
return mem, empty_bar, next
return ChannelType, Producer, Consumer
SharedMemoryChannel, SharedMemoryProducer, SharedMemoryConsumer = Channel(gl.shared_memory_descriptor,
gl.allocate_shared_memory)
TensorMemoryChannel, TensorMemoryProducer, TensorMemoryConsumer = Channel(tensor_memory_descriptor,
allocate_tensor_memory)
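# Illustrative sketch (not used by the kernel): the handshake every partition below
# follows. A producer waits on the "empty" barrier, fills the buffer and arrives on the
# "ready" barrier; a consumer waits on "ready", reads the buffer and arrives on "empty".
# In the real kernel the two halves run in different warp-specialized partitions.
@gluon.jit
def _channel_handshake_sketch(chnl):
    producer = chnl.create_producer()
    consumer = chnl.create_consumer()
    # Producer side: block until the buffer is free, write it, then mark it ready.
    buf, ready_bar, producer = producer.acquire()
    # ... fill `buf`, e.g. with a TMA load or a tensor-memory store ...
    mbarrier.arrive(ready_bar, count=1)
    # Consumer side: block until the buffer is ready, read it, then mark it empty again.
    buf, empty_bar, consumer = consumer.acquire()
    # ... read `buf` ...
    mbarrier.arrive(empty_bar, count=1)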
@gluon.jit
def get_desc_channel(desc, num_buffers: gl.constexpr, num_consumers: gl.constexpr = 1):
shape: gl.constexpr = desc.block_type.shape
layout: gl.constexpr = desc.layout
return SharedMemoryChannel.alloc(shape, desc.dtype, layout, num_buffers, num_consumers)
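# issue_async_tma_load arms the barrier with the number of bytes the TMA copy will
# deposit (mbarrier.expect) before issuing it, so a later mbarrier.wait on `bar`
# completes only once the whole block has landed in shared memory.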
@gluon.jit
def issue_async_tma_load(smem, bar, desc, offset):
mbarrier.expect(bar, desc.block_type.nbytes)
tma.async_load(desc, [offset, 0], bar, smem)
# ===-----------------------------------------------------------------------===#
# Gluon Attention
# ===-----------------------------------------------------------------------===#
@gluon.aggregate
class AttentionConfig:
qk_scale: gl.tensor
Z: gl.tensor
H: gl.tensor
N_CTX: gl.tensor
BLOCK_M: gl.constexpr
BLOCK_N: gl.constexpr
HEAD_DIM: gl.constexpr
GROUP_SIZE_N: gl.constexpr
NUM_SMS: gl.constexpr
dtype: gl.constexpr
num_warps: gl.constexpr
SPLIT_D_FACTOR: gl.constexpr
SPLIT_EXP_FACTOR: gl.constexpr
SPLIT_QK_LOAD_FACTOR: gl.constexpr
SPLIT_M: gl.constexpr
SPLIT_D: gl.constexpr
q_shape: gl.constexpr
k_shape: gl.constexpr
v_shape: gl.constexpr
qk_shape: gl.constexpr
o_shape: gl.constexpr
qk_tmem_layout: gl.constexpr
o_tmem_layout: gl.constexpr
p_tmem_layout: gl.constexpr
qk_layout: gl.constexpr
o_splitn_layout: gl.constexpr
alpha_2d_layout: gl.constexpr
num_kv_buffers: gl.constexpr
use_exp2_turnstile: gl.constexpr
@gluon.constexpr_function
def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE,
SPLIT_EXP_FACTOR, dtype, num_warps, NUM_KV_BUFFERS, USE_EXP2_TURNSTILE):
self.qk_scale = qk_scale
self.Z = Z
self.H = H
self.N_CTX = N_CTX
self.BLOCK_M = gl.constexpr(BLOCK_M)
self.BLOCK_N = gl.constexpr(BLOCK_N)
self.HEAD_DIM = gl.constexpr(HEAD_DIM)
self.GROUP_SIZE_N = gl.constexpr(GROUP_SIZE_N)
self.NUM_SMS = gl.constexpr(NUM_SMS)
self.dtype = gl.constexpr(dtype)
self.num_warps = gl.constexpr(num_warps)
self.SPLIT_D_FACTOR = gl.constexpr(2)
self.SPLIT_EXP_FACTOR = gl.constexpr(SPLIT_EXP_FACTOR)
self.SPLIT_QK_LOAD_FACTOR = gl.constexpr(2 if STAGE == 1 else 1)
self.SPLIT_M = gl.constexpr(self.BLOCK_M // 2)
self.SPLIT_D = gl.constexpr(self.HEAD_DIM // self.SPLIT_D_FACTOR)
self.q_shape = gl.constexpr([self.SPLIT_M, self.HEAD_DIM])
self.k_shape = gl.constexpr([self.BLOCK_N, self.HEAD_DIM])
self.qk_shape = gl.constexpr([self.SPLIT_M, self.BLOCK_N])
self.v_shape = gl.constexpr([self.BLOCK_N, self.HEAD_DIM])
self.o_shape = gl.constexpr([self.SPLIT_M, self.HEAD_DIM])
qk_instr_shape = get_mma_instr_shape(self.qk_shape, gl.float32)
o_instr_shape = get_mma_instr_shape(self.o_shape, gl.float32)
self.qk_tmem_layout = gl.constexpr(TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1))
self.o_tmem_layout = gl.constexpr(TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), col_stride=1))
self.p_tmem_layout = gl.constexpr(TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), col_stride=1))
o_splitn_tmem_layout: gl.constexpr = TensorMemoryLayout(
(o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR), col_stride=1)
qk_tmem_ty: gl.constexpr = tensor_memory_descriptor_type(gl.float32, self.qk_shape, self.qk_tmem_layout,
self.qk_shape)
o_splitn_tmem_ty: gl.constexpr = tensor_memory_descriptor_type(
gl.float32,
[self.o_shape[0], self.o_shape[1] // self.SPLIT_D_FACTOR],
o_splitn_tmem_layout,
self.o_shape,
)
self.qk_layout = gl.constexpr(qk_tmem_ty.get_reg_layout(num_warps=self.num_warps,
instr_variant="32x32b_splitn"))
self.o_splitn_layout = gl.constexpr(o_splitn_tmem_ty.get_reg_layout(num_warps=self.num_warps))
self.alpha_2d_layout = gl.constexpr(gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1]))
self.num_kv_buffers = gl.constexpr(NUM_KV_BUFFERS)
self.use_exp2_turnstile = gl.constexpr(USE_EXP2_TURNSTILE)
@gluon.jit
def get_program(self, pid_m, pid_n):
start_m = pid_m
off_hz = pid_n
off_z = off_hz // self.H
off_h = off_hz % self.H
offset_y = off_z * (self.N_CTX * self.H) + off_h * self.N_CTX
qo_offset_y = offset_y + start_m * self.BLOCK_M
return AttentionProgram(self, start_m, off_hz, offset_y, qo_offset_y)
@gluon.aggregate
class ProgramScheduler:
config: AttentionConfig
start_pid: gl.tensor
num_pid_n: gl.tensor
num_pid_in_group: gl.tensor
num_tiles: gl.tensor
@gluon.jit
def create(config):
start_pid = gl.program_id(0)
num_pid_m = gl.cdiv(config.N_CTX, config.BLOCK_M)
num_pid_n = config.Z * config.H
num_pid_in_group = num_pid_m * config.GROUP_SIZE_N
num_tiles = num_pid_m * num_pid_n
return ProgramScheduler(config, start_pid, num_pid_n, num_pid_in_group, num_tiles)
@gluon.jit
def get_program(self, tile_id):
group_id = tile_id // self.num_pid_in_group
first_pid_n = group_id * self.config.GROUP_SIZE_N
group_size_n = min(self.num_pid_n - first_pid_n, self.config.GROUP_SIZE_N)
pid_n = first_pid_n + (tile_id % group_size_n)
pid_m = (tile_id % self.num_pid_in_group) // group_size_n
return self.config.get_program(pid_m, pid_n)
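# Example: with num_pid_m=4, num_pid_n=4 and GROUP_SIZE_N=2, consecutive tile_ids map to
#   (pid_m, pid_n) = (0,0), (0,1), (1,0), (1,1), (2,0), (2,1), (3,0), (3,1), (0,2), (0,3), ...
# so a group of GROUP_SIZE_N batch*head programs is swept across all query blocks before
# moving on, keeping the K/V tiles those programs share resident in L2.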
@gluon.aggregate
class AttentionProgram:
config: AttentionConfig
start_m: gl.tensor
off_hz: gl.tensor
offset_y: gl.tensor
qo_offset_y: gl.tensor
@gluon.jit
def get_fused_loop_bounds(self, STAGE: gl.constexpr):
BLOCK_M: gl.constexpr = self.config.BLOCK_M
if STAGE == 1:
return 0, self.config.N_CTX
elif STAGE == 2:
return self.start_m * BLOCK_M, (self.start_m + 1) * BLOCK_M
elif STAGE == 3:
return 0, (self.start_m + 1) * BLOCK_M
else:
return 0, 0
@gluon.jit
def get_loop_bounds(self, STAGE: gl.constexpr):
BLOCK_M: gl.constexpr = self.config.BLOCK_M
if STAGE == 1:
lo, hi = 0, self.start_m * BLOCK_M
elif STAGE == 2:
lo, hi = self.start_m * BLOCK_M, (self.start_m + 1) * BLOCK_M
else:
lo, hi = 0, self.config.N_CTX
return lo, hi
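# STAGE encoding: the kernel is launched with STAGE=1 (non-causal) or STAGE=3 (causal).
# The load/MMA/correction partitions walk the fused range from get_fused_loop_bounds:
# the full context for STAGE=1, or everything up to and including the diagonal block row
# for STAGE=3. The softmax partitions instead split the causal case into two passes via
# get_loop_bounds: STAGE=1 covers the fully unmasked blocks below the diagonal and
# STAGE=2 covers the diagonal block, which is the only place the causal mask is applied.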
# ===-----------------------------------------------------------------------===#
# _gluon_attn
# ===-----------------------------------------------------------------------===#
@gluon.jit
def _borrow_s_as_p(config, s_tmem):
p_tmem = s_tmem.slice(0, config.BLOCK_N // 2)
return p_tmem._reinterpret(config.dtype, config.qk_shape, config.p_tmem_layout)
@gluon.jit
def _borrow_s_as_alpha(config, s_tmem):
alpha_tmem = s_tmem.slice(config.BLOCK_N // 2, 1)
alpha_layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
return alpha_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], alpha_layout)
@gluon.jit
def _borrow_s_for_epilogue(config, s_tmem):
m_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 1, 1)
l_i_tmem = s_tmem.slice(config.BLOCK_N // 2 + 2, 1)
layout: gl.constexpr = TensorMemoryLayout([config.SPLIT_M, 1], col_stride=1)
m_i_tmem = m_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
l_i_tmem = l_i_tmem._reinterpret(gl.float32, [config.SPLIT_M, 1], layout)
return m_i_tmem, l_i_tmem
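# The borrow helpers above carve scratch space out of the S (QK accumulator) tensor
# memory. S is SPLIT_M x BLOCK_N in fp32, so its first BLOCK_N//2 32-bit columns can be
# reinterpreted as a SPLIT_M x BLOCK_N tile of P in the compute dtype, while columns
# BLOCK_N//2, BLOCK_N//2+1 and BLOCK_N//2+2 are reused for the per-row alpha, m_i and
# l_i values handed off to the correction partition.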
@gluon.constexpr_function
def _get_split_n_layout(layout: gl.constexpr, SPLIT_FACTOR: gl.constexpr = 2):
assert isinstance(layout, gl.DistributedLinearLayout), "split_n requires a distributed layout"
assert SPLIT_FACTOR == 1 or SPLIT_FACTOR == 2, "split_n requires a split factor of 1 or 2"
if SPLIT_FACTOR == 1:
return layout
else:
target = [0, layout.shape[1] // 2] # [0, 2^{m-1}]
last_reg_idx = len(layout.reg_bases) - 1
reg_last = layout.reg_bases[last_reg_idx]
if reg_last == target:
return layout
ret = copy.deepcopy(layout)
# Find [0, 2^{m-1}] across lists and swap it with last reg
for L in (ret.reg_bases, ret.lane_bases, ret.warp_bases, ret.block_bases):
for i, b in enumerate(L):
if b == target:
L[i], ret.reg_bases[last_reg_idx] = reg_last, target
return ret
assert False, f"split_n requires having a basis {target}. Got\n{layout}"
@gluon.jit
def _split_n(x, SPLIT_FACTOR: gl.constexpr = 2):
if SPLIT_FACTOR == 1:
return (x, )
else:
layout: gl.constexpr = _get_split_n_layout(x.type.layout)
x0, x1 = x.reshape([x.shape[0], 2, x.shape[1] // 2]).permute(0, 2, 1).split()
x0 = gl.convert_layout(x0, layout, assert_trivial=True)
x1 = gl.convert_layout(x1, layout, assert_trivial=True)
return _split_n(x0, SPLIT_FACTOR // 2) + _split_n(x1, SPLIT_FACTOR // 2)
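# _split_n and _join_n below split a [M, N] register tensor into its left and right
# column halves (and join them back) without moving data between threads: the layout
# computed by _get_split_n_layout makes the split a trivial relayout, which
# convert_layout verifies via assert_trivial=True.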
@gluon.constexpr_function
def _get_join_n_layout(layout, SPLIT_FACTOR: gl.constexpr = 2):
assert isinstance(layout, gl.DistributedLinearLayout), "join_n requires a Linear layout"
shape = list(layout.shape)
regs = [[0, shape[1] * (1 << i)] for i in range(int(math.log2(SPLIT_FACTOR)))]
shape[1] *= SPLIT_FACTOR
return gl.DistributedLinearLayout(
layout.reg_bases + regs,
layout.lane_bases,
layout.warp_bases,
layout.block_bases,
shape,
)
@gluon.jit
def _join_n(xs):
if len(xs) == 1:
return xs[0]
else:
x0 = _join_n(xs[:len(xs) // 2])
x1 = _join_n(xs[len(xs) // 2:])
layout: gl.constexpr = _get_join_n_layout(x0.type.layout)
x = gl.join(x0, x1).permute(0, 2, 1).reshape([x0.shape[0], x0.shape[1] * 2])
return gl.convert_layout(x, layout, assert_trivial=True)
@gluon.jit
def _attn_fwd_load(config, chnls, descs, M, STAGE: gl.constexpr):
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
desc_q, desc_k, desc_v, desc_o = descs
q_producer = q_chnl.create_producer()
kv_producer = kv_chnl.create_producer()
scheduler = ProgramScheduler.create(config)
for pid in range(scheduler.start_pid, scheduler.num_tiles, config.NUM_SMS):
prog = scheduler.get_program(pid)
lo, hi = prog.get_fused_loop_bounds(STAGE)
q0_offset = prog.qo_offset_y + config.SPLIT_M * 0
q0_smem, q0_bar, q_producer = q_producer.acquire()
issue_async_tma_load(q0_smem, q0_bar, desc_q, q0_offset)
offsetkv_y = prog.offset_y + lo
k_smem, k_bar, kv_producer = kv_producer.acquire()
issue_async_tma_load(k_smem, k_bar, desc_k, offsetkv_y)
q1_offset = prog.qo_offset_y + config.SPLIT_M * 1
q1_smem, q1_bar, q_producer = q_producer.acquire()
issue_async_tma_load(q1_smem, q1_bar, desc_q, q1_offset)
v_smem, v_bar, kv_producer = kv_producer.acquire()
issue_async_tma_load(v_smem, v_bar, desc_v, offsetkv_y)
for start_n in range(lo + config.BLOCK_N, hi, config.BLOCK_N):
offsetkv_y = prog.offset_y + start_n
k_smem, k_bar, kv_producer = kv_producer.acquire()
issue_async_tma_load(k_smem, k_bar, desc_k, offsetkv_y)
v_smem, v_bar, kv_producer = kv_producer.acquire()
issue_async_tma_load(v_smem, v_bar, desc_v, offsetkv_y)
@gluon.jit
def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
desc_q, desc_k, desc_v, desc_o = descs
q_consumer = q_chnl.create_consumer()
kv_consumer = kv_chnl.create_consumer()
o_producer = o_chnl.create_producer()
s0_producer = s0_chnl.create_producer()
s1_producer = s1_chnl.create_producer()
scheduler = ProgramScheduler.create(config)
for pid in range(scheduler.start_pid, scheduler.num_tiles, config.NUM_SMS):
prog = scheduler.get_program(pid)
lo, hi = prog.get_fused_loop_bounds(STAGE)
num_mmas = (hi - lo) // config.BLOCK_N
q0_smem, q0_bar, q_consumer = q_consumer.acquire()
k_smem, k_bar, kv_consumer = kv_consumer.acquire()
s0_tmem, s0_bar, s0_producer = s0_producer.acquire()
tcgen05_mma(q0_smem, k_smem.permute((1, 0)), s0_tmem, use_acc=False, mbarriers=[s0_bar])
q1_smem, q1_bar, q_consumer = q_consumer.acquire()
s1_tmem, s1_bar, s1_producer = s1_producer.acquire()
tcgen05_mma(q1_smem, k_smem.permute((1, 0)), s1_tmem, use_acc=False, mbarriers=[s1_bar, k_bar])
v_smem, v_bar, kv_consumer = kv_consumer.acquire()
o0_tmem, o0_bar, o_producer = o_producer.acquire()
s0_tmem, s0_bar, s0_producer = s0_producer.acquire()
p0_tmem = _borrow_s_as_p(config, s0_tmem)
tcgen05_mma(p0_tmem, v_smem, o0_tmem, use_acc=False, mbarriers=[o0_bar])
o1_init = False
for _ in range(num_mmas - 1):
k_smem, k_bar, kv_consumer = kv_consumer.acquire()
tcgen05_mma(q0_smem, k_smem.permute((1, 0)), s0_tmem, use_acc=False, mbarriers=[s0_bar])
o1_tmem, o1_bar, o_producer = o_producer.acquire()
s1_tmem, s1_bar, s1_producer = s1_producer.acquire()
p1_tmem = _borrow_s_as_p(config, s1_tmem)
tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o1_init, mbarriers=[o1_bar, v_bar])
o1_init = True
tcgen05_mma(q1_smem, k_smem.permute((1, 0)), s1_tmem, use_acc=False, mbarriers=[s1_bar, k_bar])
v_smem, v_bar, kv_consumer = kv_consumer.acquire()
o0_tmem, o0_bar, o_producer = o_producer.acquire()
s0_tmem, s0_bar, s0_producer = s0_producer.acquire()
p0_tmem = _borrow_s_as_p(config, s0_tmem)
tcgen05_mma(p0_tmem, v_smem, o0_tmem, mbarriers=[o0_bar])
tcgen05_commit(q0_bar)
tcgen05_commit(q1_bar)
o1_tmem, o1_bar, o_producer = o_producer.acquire()
s1_tmem, s1_bar, s1_producer = s1_producer.acquire()
p1_tmem = _borrow_s_as_p(config, s1_tmem)
tcgen05_mma(p1_tmem, v_smem, o1_tmem, use_acc=o1_init, mbarriers=[o1_bar, v_bar, s0_bar, s1_bar])
@gluon.jit
def _mask_scalar(qk, col_limit_right, s, i):
col_lim_right_s = col_limit_right - s
col_lim_right_cur = max(col_lim_right_s, 0)
mask = -1 << col_lim_right_cur
mask_i_bit = (mask & (1 << i)) == 0
return gl.where(mask_i_bit, qk, -float("inf"))
@gluon.jit
def _apply_causal_mask(qk, col_limit_right):
# Apply causal mask via a bitmask calculated for each block of 16 elements.
# This allows the efficient R2P (register to predicate) instruction to be used at the SASS level.
# Credit to Tri Dao,
# https://github.com/Dao-AILab/flash-attention/commit/bac1001e4f6caa09d70537495d6746a685a2fa78
#
# NOTE: We use map_elementwise here in order to generate an interleaved sequence of instructions
# that processes one element of qk at a time. This improves ptxas's resulting SASS.
offs_n = gl.arange(0, qk.shape[1])[None, :]
s = offs_n & ~0xf
i = offs_n & 0xf
return gl.map_elementwise(_mask_scalar, qk, col_limit_right, s, i)
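# Worked example: if a row may see col_limit_right = 21 columns of this K block, then for
# the column group s = 16..31 we get col_lim_right_cur = 21 - 16 = 5 and
# mask = -1 << 5 = ...1111100000b, so bits 0-4 are clear (columns 16-20 keep their score)
# and bits 5+ are set (columns 21-31 become -inf).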
@gluon.jit
def _compute_and_store_exp2(config, qk, p_tmem):
SIZE: gl.constexpr = p_tmem.shape[1] // config.SPLIT_EXP_FACTOR
qks = _split_n(qk, config.SPLIT_EXP_FACTOR)
ps = ()
for i in gl.static_range(config.SPLIT_EXP_FACTOR):
p = gl.exp2(qks[i])
p_tmem.slice(i * SIZE, SIZE).store(p.to(config.dtype))
ps = ps + (p, )
return _join_n(ps)
@gluon.jit
def _subtiled_qk_load(config, s_tmem, use_tmem_red: gl.constexpr):
SIZE: gl.constexpr = s_tmem.shape[1] // config.SPLIT_QK_LOAD_FACTOR
qks = ()
if use_tmem_red:
red_total = None
for i in gl.static_range(config.SPLIT_QK_LOAD_FACTOR):
vals, reds = s_tmem.slice(i * SIZE, SIZE).load_max()
red_total = reds if red_total is None else gl.maximum(red_total, reds)
qks = qks + (vals, )
return _join_n(qks), red_total
else:
for i in gl.static_range(config.SPLIT_QK_LOAD_FACTOR):
qks = qks + (s_tmem.slice(i * SIZE, SIZE).load(), )
return _join_n(qks), None
@gluon.jit
def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
s_consumer, corr_producer, exp_turnstile, corr_bar, #
offs_m, m_i, l_i, STAGE: gl.constexpr, use_tmem_red: gl.constexpr):
lo, hi = prog.get_loop_bounds(STAGE)
for start_n in range(lo, hi, config.BLOCK_N):
s_tmem, s_bar, s_consumer = s_consumer.acquire()
qk, qk_max = _subtiled_qk_load(config, s_tmem, use_tmem_red)
if STAGE == 2:
col_limit_right = (offs_m - start_n + 1)[:, None]
qk = _apply_causal_mask(qk, col_limit_right)
if use_tmem_red:
qk_max = gl.convert_layout(qk_max, m_i.type.layout)
m_ij = gl.maximum(m_i, qk_max * config.qk_scale)
else:
m_ij = gl.maximum(m_i, gl.max(qk, 1) * config.qk_scale)
alpha = gl.exp2(m_i - m_ij)
alpha_tmem = _borrow_s_as_alpha(config, s_tmem)
alpha_tmem.store(gl.convert_layout(alpha.expand_dims(1), config.alpha_2d_layout))
mbarrier.arrive(corr_bar, count=1)
rowmax = float2.pack(-m_ij[:, None].broadcast_to(qk.shape), axis=1)
qk = float2.pack(qk, axis=1)
qk = float2.fma(qk, float2.full_like(qk, config.qk_scale), rowmax)
qk = float2.unpack(qk, axis=1)
# Force the softmax partitions to take turns in the EX2 section. This
# prevents contention for the EX2 unit and improves utilization.
if config.use_exp2_turnstile:
_, exp_bar, exp_turnstile = exp_turnstile.acquire()
# FIXME: When using FADD2 reductions, ptxas misbehaves and spills far
# below the register limit in the FADD2, FMUL2, EX2 section. Subtile by
# 4 to minimize the spilling.
p_tmem = _borrow_s_as_p(config, s_tmem)
p = _compute_and_store_exp2(config, qk, p_tmem)
mbarrier.arrive(s_bar, count=1)
_, corr_bar, corr_producer = corr_producer.acquire()
if config.use_exp2_turnstile:
mbarrier.arrive(exp_bar, count=1)
l_ij = float2.pack2(*_split_n(p)).sum(axis=1)
l_ij = Float2Tensor(gl.convert_layout(l_ij.value, l_i.value.type.layout, assert_trivial=True))
alpha = gl.convert_layout(alpha, l_i.value.type.layout, assert_trivial=True)
l_i = float2.fma(l_i, float2.pack2(alpha, alpha), l_ij)
m_i = m_ij
return m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile
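# Per K/V block, the loop above is the standard online-softmax update, written with
# base-2 exponentials (qk_scale already folds in log2(e) = 1.44269504):
#   m_new = max(m_old, qk_scale * rowmax(qk))
#   alpha = exp2(m_old - m_new)          # rescale factor handed to the correction partition
#   p     = exp2(qk_scale * qk - m_new)  # stored to TMEM as the second MMA operand
#   l_new = alpha * l_old + rowsum(p)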
@gluon.jit
def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr, #
s_chnl, corr_chnl, exp_turnstile, use_tmem_red: gl.constexpr):
qk_slice_dim1: gl.constexpr = gl.SliceLayout(1, config.qk_layout)
sum_layout: gl.constexpr = _get_split_n_layout(config.qk_layout)
s_consumer = s_chnl.create_consumer()
corr_producer = corr_chnl.create_producer()
_, corr_bar, corr_producer = corr_producer.acquire()
scheduler = ProgramScheduler.create(config)
for pid in range(scheduler.start_pid, scheduler.num_tiles, config.NUM_SMS):
prog = scheduler.get_program(pid)
offs_m = prog.start_m * config.BLOCK_M
offs_m += gl.arange(tile_id * config.SPLIT_M, (1 + tile_id) * config.SPLIT_M)
m_i = gl.full([config.SPLIT_M], -float("inf"), gl.float32, qk_slice_dim1)
# Accumulate into 2 row-sums so the reduction can be performed with FADD2.
l_i = gl.full([config.SPLIT_M], 0.0, gl.float32, gl.SliceLayout(1, sum_layout))
l_i = float2.pack2(l_i, l_i)
if STAGE & 1:
m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = _softmax_inner_loop( #
tile_id, config, prog, s_consumer, corr_producer, exp_turnstile, corr_bar, #
offs_m, m_i, l_i, STAGE=4 - STAGE, use_tmem_red=use_tmem_red)
if STAGE & 2:
m_i, l_i, corr_bar, s_consumer, corr_producer, exp_turnstile = _softmax_inner_loop( #
tile_id, config, prog, s_consumer, corr_producer, exp_turnstile, corr_bar, #
offs_m, m_i, l_i, STAGE=2, use_tmem_red=use_tmem_red)
l_i0, l_i1 = float2.unpack2(l_i)
l_i = l_i0 + l_i1
s_tmem, s_bar, s_consumer = s_consumer.acquire()
m_i_tmem, l_i_tmem = _borrow_s_for_epilogue(config, s_tmem)
m_i_tmem.store(gl.convert_layout(m_i.expand_dims(1), config.alpha_2d_layout))
l_i_tmem.store(gl.convert_layout(l_i.expand_dims(1), config.alpha_2d_layout))
mbarrier.arrive(corr_bar, count=1)
_, corr_bar, corr_producer = corr_producer.acquire()
mbarrier.arrive(s_bar, count=1)
@gluon.jit
def _attn_fwd_softmax0(config, chnls, descs, M, STAGE: gl.constexpr, use_tmem_red: gl.constexpr):
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
desc_q, desc_k, desc_v, desc_o = descs
_softmax_tile(0, config, M, desc_o, STAGE, s0_chnl, c0_chnl, exp_turnstile.create_producer(), use_tmem_red)
@gluon.jit
def _attn_fwd_softmax1(config, chnls, descs, M, STAGE: gl.constexpr, use_tmem_red: gl.constexpr):
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
desc_q, desc_k, desc_v, desc_o = descs
_softmax_tile(1, config, M, desc_o, STAGE, s1_chnl, c1_chnl, exp_turnstile.create_consumer(), use_tmem_red)
@gluon.jit
def _attn_fwd_epilogue(config, chnls, descs, M, STAGE: gl.constexpr):
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
desc_q, desc_k, desc_v, desc_o = descs
epi_consumer = epi_chnl.create_consumer()
scheduler = ProgramScheduler.create(config)
for pid in range(scheduler.start_pid, scheduler.num_tiles, config.NUM_SMS):
prog = scheduler.get_program(pid)
o0_smem, o0_bar, epi_consumer = epi_consumer.acquire()
tma.async_copy_shared_to_global(desc_o, [prog.qo_offset_y + config.SPLIT_M * 0, 0], o0_smem)
o1_smem, o1_bar, epi_consumer = epi_consumer.acquire()
tma.async_copy_shared_to_global(desc_o, [prog.qo_offset_y + config.SPLIT_M * 1, 0], o1_smem)
tma.store_wait(1)
mbarrier.arrive(o0_bar, count=1)
tma.store_wait(0)
mbarrier.arrive(o1_bar, count=1)
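# In the epilogue above, tma.store_wait(N) blocks until at most N TMA stores remain in
# flight: store_wait(1) releases the first output buffer while the second store is still
# draining, and store_wait(0) releases the second once everything has committed.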
@gluon.jit
def _attn_fwd_correction_rescale(config, s_tmem, corr_consumer, o_consumer):
alpha_layout: gl.constexpr = gl.SliceLayout(1, config.o_splitn_layout)
o_tmem, o_bar, o_consumer = o_consumer.acquire()
_, corr_bar, corr_consumer = corr_consumer.acquire()
alpha = _borrow_s_as_alpha(config, s_tmem).load(config.alpha_2d_layout)
mbarrier.arrive(corr_bar, count=1)
alpha = gl.convert_layout(alpha.reshape([config.SPLIT_M]), alpha_layout)
alpha = float2.pack(alpha[:, None].broadcast_to(config.o_shape[0], config.SPLIT_D), axis=1)
for i in gl.static_range(config.SPLIT_D_FACTOR):
o_ref = o_tmem.slice(i * config.SPLIT_D, config.SPLIT_D)
o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
o = o * alpha
o_ref.store(float2.unpack(o, axis=1))
mbarrier.arrive(o_bar, count=1)
return corr_consumer, o_consumer
@gluon.jit
def _attn_fwd_correction_epilogue(config, prog, s_tmem, M, corr_consumer, epi_producer, o_consumer):
alpha_layout: gl.constexpr = gl.SliceLayout(1, config.o_splitn_layout)
_, corr_bar, corr_consumer = corr_consumer.acquire()
m_i_tmem, l_i_tmem = _borrow_s_for_epilogue(config, s_tmem)
m_i = m_i_tmem.load(config.alpha_2d_layout).reshape([config.SPLIT_M])
m_i = gl.convert_layout(m_i, alpha_layout)
l_i = l_i_tmem.load(config.alpha_2d_layout).reshape([config.SPLIT_M])
l_i = gl.convert_layout(l_i, alpha_layout)
mbarrier.arrive(corr_bar, count=1)
o_smem, epi_bar, epi_producer = epi_producer.acquire()
o_tmem, o_bar, o_consumer = o_consumer.acquire()
# Shared memory subtile size is limited by the swizzle byte size.
contigDimSize: gl.constexpr = o_smem.type.layout.swizzle_byte_width * 8 // o_smem.type.element_ty.primitive_bitwidth
if o_smem.type.shape[1] // config.SPLIT_D_FACTOR >= contigDimSize:
SPLIT_N_FACTOR: gl.constexpr = config.SPLIT_D_FACTOR
else:
SPLIT_N_FACTOR: gl.constexpr = 1
gl.static_assert(o_smem.type.shape[1] // SPLIT_N_FACTOR >= contigDimSize,
"Block shape is too small for the swizzle byte size in NVMMA Shared Layout")
SPLIT_N: gl.constexpr = o_smem.type.shape[1] // SPLIT_N_FACTOR
scale = float2.pack((1 / l_i)[:, None].broadcast_to(config.o_shape[0], SPLIT_N), axis=1)
for i in gl.static_range(SPLIT_N_FACTOR):
o_ref = o_tmem.slice(i * SPLIT_N, SPLIT_N)
o = float2.pack(o_ref.load(config.o_splitn_layout), axis=1)
o = o * scale
o_smem.slice(i * SPLIT_N, SPLIT_N, dim=1).store(float2.unpack(o, axis=1).to(config.dtype))
fence_async_shared()
mbarrier.arrive(epi_bar, count=1)
mbarrier.arrive(o_bar, count=1)
m_i += gl.log2(l_i)
coalesced: gl.constexpr = gl.BlockedLayout([1], [32], [config.num_warps], [0])
offs_m = prog.start_m * config.BLOCK_M
offs_m += gl.arange(0 * config.SPLIT_M, 1 * config.SPLIT_M, coalesced)
m_ptrs = M + prog.off_hz * config.N_CTX + offs_m
gl.store(m_ptrs, gl.convert_layout(m_i, coalesced))
return corr_consumer, epi_producer, o_consumer
@gluon.jit
def _attn_fwd_correction(config, chnls, descs, M, STAGE: gl.constexpr):
q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile = chnls
s0_tmem = s0_chnl.mem.index(0)
s1_tmem = s1_chnl.mem.index(0)
corr0_consumer = c0_chnl.create_consumer()
corr1_consumer = c1_chnl.create_consumer()
o_consumer = o_chnl.create_consumer()
epi_producer = epi_chnl.create_producer()
scheduler = ProgramScheduler.create(config)
for pid in range(scheduler.start_pid, scheduler.num_tiles, config.NUM_SMS):
prog = scheduler.get_program(pid)
lo, hi = prog.get_fused_loop_bounds(STAGE)
num_corrections = (hi - lo) // config.BLOCK_N
_, corr0_bar, corr0_consumer = corr0_consumer.acquire()
mbarrier.arrive(corr0_bar, count=1)
_, corr1_bar, corr1_consumer = corr1_consumer.acquire()
mbarrier.arrive(corr1_bar, count=1)
for i in range(num_corrections - 1):
corr0_consumer, o_consumer = _attn_fwd_correction_rescale(config, s0_tmem, corr0_consumer, o_consumer)
corr1_consumer, o_consumer = _attn_fwd_correction_rescale(config, s1_tmem, corr1_consumer, o_consumer)
corr0_consumer, epi_producer, o_consumer = _attn_fwd_correction_epilogue( #
config, prog, s0_tmem, M, corr0_consumer, epi_producer, o_consumer)
corr1_consumer, epi_producer, o_consumer = _attn_fwd_correction_epilogue( #
config, prog, s1_tmem, M, corr1_consumer, epi_producer, o_consumer)
def attention_repr(specialization):
name = "gluon_attention"
# Up to 150 TFLOPS faster for fp8!
if specialization.constants["dtype"] == gl.float8e5:
name = "cutlass_" + name
return name
@gluon.jit(do_not_specialize=["Z", "H", "N_CTX"], repr=attention_repr)
def attention_kernel( #
sm_scale, M, Z, H, N_CTX, desc_q, desc_k, desc_v, desc_o, #
BLOCK_M: gl.constexpr, BLOCK_N: gl.constexpr, HEAD_DIM: gl.constexpr, #
GROUP_SIZE_N: gl.constexpr, NUM_SMS: gl.constexpr, STAGE: gl.constexpr, SPLIT_EXP_FACTOR: gl.constexpr, #
dtype: gl.constexpr, num_warps: gl.constexpr, use_tmem_red: gl.constexpr, NUM_KV_BUFFERS: gl.constexpr,
USE_EXP2_TURNSTILE: gl.constexpr):
qk_scale = sm_scale * 1.44269504
config = AttentionConfig(qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE_N, NUM_SMS, STAGE,
SPLIT_EXP_FACTOR, #
dtype, num_warps, NUM_KV_BUFFERS, USE_EXP2_TURNSTILE)
q_chnl = get_desc_channel(desc_q, num_buffers=2)
kv_chnl = get_desc_channel(desc_k, num_buffers=config.num_kv_buffers)
o_chnl = TensorMemoryChannel.alloc(config.o_shape, gl.float32, config.o_tmem_layout, num_buffers=2)
epi_chnl = SharedMemoryChannel.alloc(config.o_shape, config.dtype, gl.constexpr(desc_o.layout), num_buffers=2)
s0_chnl = TensorMemoryChannel.alloc(config.qk_shape, gl.float32, config.qk_tmem_layout, num_buffers=1)
s1_chnl = TensorMemoryChannel.alloc(config.qk_shape, gl.float32, config.qk_tmem_layout, num_buffers=1)
c0_chnl = SharedMemoryChannel.alloc([1], gl.int8, gl.constexpr(mbarrier.MBarrierLayout()), num_buffers=1)
c1_chnl = SharedMemoryChannel.alloc([1], gl.int8, gl.constexpr(mbarrier.MBarrierLayout()), num_buffers=1)
exp_turnstile = SharedMemoryChannel.alloc([1], gl.int8, gl.constexpr(mbarrier.MBarrierLayout()), num_buffers=1)
chnls = (q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile)
descs = (desc_q, desc_k, desc_v, desc_o)
gl.warp_specialize([
(_attn_fwd_correction, (config, chnls, descs, M, STAGE)),
(_attn_fwd_softmax0, (config, chnls, descs, M, STAGE, use_tmem_red)),
(_attn_fwd_softmax1, (config, chnls, descs, M, STAGE, use_tmem_red)),
(_attn_fwd_mma, (config, chnls, descs, M, STAGE)),
(_attn_fwd_load, (config, chnls, descs, M, STAGE)),
(_attn_fwd_epilogue, (config, chnls, descs, M, STAGE)),
], [4, 4, 1, 1, 1], [192, 192, 24, 24, 24])
q_chnl.release()
kv_chnl.release()
o_chnl.release()
epi_chnl.release()
s0_chnl.release()
s1_chnl.release()
c0_chnl.release()
c1_chnl.release()
exp_turnstile.release()
# ===-----------------------------------------------------------------------===#
# Entry Point
# ===-----------------------------------------------------------------------===#
def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
def is_blackwell():
return is_cuda() and torch.cuda.get_device_capability()[0] == 10
def is_blackwell_ultra():
return is_cuda() and torch.cuda.get_device_capability()[0:2] == (10, 3)
@dataclass(frozen=True, slots=True)
class KernelConfig:
BLOCK_M: int = 256
BLOCK_N: int = 128
GROUP_SIZE_N: int | None = None
SPLIT_EXP_FACTOR: int | None = None
NUM_WARPS: int = 4
MAXNREG: int = 128
OCCUPANCY: int = 1
USE_TMEM_RED: bool = False
NUM_KV_BUFFERS: int | None = None
USE_EXP2_TURNSTILE: bool | None = None
def _default_split_exp_factor(head_dim: int) -> int:
return max(1, 256 // head_dim)
def _default_num_kv_buffers(head_dim: int, dtype: torch.dtype) -> int:
is_fp16 = dtype in [torch.float16, torch.bfloat16]
if is_fp16:
return 3 if head_dim == 128 else 6
return 4 if head_dim == 128 else 8
def select_kernel_config(
head_dim: int,
n_ctx: int,
dtype: torch.dtype,
causal: bool,
use_tmem_red: bool,
override: KernelConfig | None = None,
) -> KernelConfig:
is_fp8 = dtype == torch.float8_e5m2
is_bf16 = dtype == torch.bfloat16
is_bwu = is_blackwell_ultra()
block_m = 256
block_n = 128
group_size_n = 1
split_exp_factor = _default_split_exp_factor(head_dim)
num_warps = 4
maxnreg = 128
occupancy = 1
use_selected_tmem_red = (use_tmem_red or (is_bwu and not causal)) and not causal
num_kv_buffers = _default_num_kv_buffers(head_dim, dtype)
use_exp2_turnstile = head_dim == 64
if causal:
group_size_n = 8 if head_dim == 64 or n_ctx <= 2048 else 4
if head_dim == 128:
split_exp_factor = 4
if not causal and is_bf16 and n_ctx <= 2048:
group_size_n = 4
elif not causal and head_dim == 64 and use_selected_tmem_red:
split_exp_factor = 1
if n_ctx <= 1024:
num_kv_buffers = 2
elif n_ctx >= 8192:
maxnreg = 112
elif causal and head_dim == 64:
num_kv_buffers = 2
if n_ctx <= 1024:
split_exp_factor = 2
else:
use_exp2_turnstile = False
if is_fp8:
if causal and head_dim == 64:
group_size_n = 8 if n_ctx <= 2048 else 4
split_exp_factor = 4 if n_ctx <= 2048 else 2
maxnreg = 112 if n_ctx >= 4096 else 128
use_selected_tmem_red = False
num_kv_buffers = 2
use_exp2_turnstile = n_ctx <= 1024
elif causal and head_dim == 128:
group_size_n = 8 if n_ctx <= 2048 else 4
split_exp_factor = 2 if n_ctx <= 2048 else 8
maxnreg = 128
use_selected_tmem_red = False
num_kv_buffers = 4
use_exp2_turnstile = False
elif not causal and head_dim == 64:
group_size_n = 1
split_exp_factor = 2
maxnreg = 128
use_selected_tmem_red = is_bwu
num_kv_buffers = 2 if n_ctx <= 1024 else 8
use_exp2_turnstile = True
elif not causal and head_dim == 128:
group_size_n = 1
split_exp_factor = 4 if n_ctx <= 2048 else 8
maxnreg = 128
use_selected_tmem_red = is_bwu
num_kv_buffers = 4
use_exp2_turnstile = False
else:
group_size_n = 4 if causal else 1
split_exp_factor = _default_split_exp_factor(head_dim)
use_selected_tmem_red = use_tmem_red and not causal
config = KernelConfig(
BLOCK_M=block_m,
BLOCK_N=block_n,
GROUP_SIZE_N=group_size_n,
SPLIT_EXP_FACTOR=split_exp_factor,
NUM_WARPS=num_warps,
MAXNREG=maxnreg,
OCCUPANCY=occupancy,
USE_TMEM_RED=use_selected_tmem_red,
NUM_KV_BUFFERS=num_kv_buffers,
USE_EXP2_TURNSTILE=use_exp2_turnstile,
)
if override is None:
return config
values = {field.name: getattr(override, field.name) for field in fields(KernelConfig)}
values = {name: getattr(config, name) if value is None else value for name, value in values.items()}
return KernelConfig(**values)
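# A minimal sketch of overriding the selection heuristics (illustrative values): fields
# left as None in the override (GROUP_SIZE_N, SPLIT_EXP_FACTOR, NUM_KV_BUFFERS and
# USE_EXP2_TURNSTILE by default) inherit the selected value, while every other field,
# including the dataclass defaults, takes the override's value.
#
#   cfg = select_kernel_config(head_dim=128, n_ctx=4096, dtype=torch.float16, causal=True,
#                              use_tmem_red=False, override=KernelConfig(BLOCK_N=64, MAXNREG=112))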
def torch_dtype_to_triton(dtype):
if dtype == torch.float8_e5m2:
return gl.float8e5
return getattr(gl, str(dtype).split('.')[1])
def make_tensor_desc(x, shape, strides, block_shape):
layout = gl.NVMMASharedLayout.get_default_for(block_shape, torch_dtype_to_triton(x.dtype))
return TensorDescriptor(x, shape=shape, strides=strides, block_shape=block_shape, layout=layout)
def attention_forward(q, k, v, causal, sm_scale, o=None, M=None, *, use_tmem_red=False, p: KernelConfig | None = None):
if isinstance(o, bool) and M is None and use_tmem_red is False:
use_tmem_red = o
o = None
HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
HEAD_DIM_V = v.shape[-1]
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
assert HEAD_DIM_K in {16, 32, 64, 128, 256}
stage = 3 if causal else 1
p = select_kernel_config(HEAD_DIM_K, q.shape[2], q.dtype, causal, use_tmem_red, override=p)
if o is None:
o = torch.empty_like(q)
if M is None:
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
y_dim = q.shape[0] * q.shape[1] * q.shape[2]
# The kernel will split BLOCK_M into two subtiles.
BLOCK_M = p.BLOCK_M
BLOCK_N = p.BLOCK_N
SPLIT_M = BLOCK_M // 2
GROUP_SIZE_N = p.GROUP_SIZE_N
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count * p.OCCUPANCY
desc_q = make_tensor_desc(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[SPLIT_M, HEAD_DIM_K])
desc_v = make_tensor_desc(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
desc_k = make_tensor_desc(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[BLOCK_N, HEAD_DIM_K])
desc_o = make_tensor_desc(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[SPLIT_M, HEAD_DIM_K])
num_pid_m = triton.cdiv(q.shape[2], BLOCK_M)
num_pid_n = q.shape[0] * q.shape[1]
grid = min(NUM_SMS, num_pid_m * num_pid_n)
attention_kernel[(grid, )](
sm_scale, M, q.shape[0], q.shape[1], q.shape[2], #
desc_q, desc_k, desc_v, desc_o, #
BLOCK_M, BLOCK_N, HEAD_DIM_K, GROUP_SIZE_N, NUM_SMS, #
SPLIT_EXP_FACTOR=p.SPLIT_EXP_FACTOR, STAGE=stage, dtype=torch_dtype_to_triton(q.dtype), #
num_warps=p.NUM_WARPS, maxnreg=p.MAXNREG, use_tmem_red=p.USE_TMEM_RED, NUM_KV_BUFFERS=p.NUM_KV_BUFFERS,
USE_EXP2_TURNSTILE=p.USE_EXP2_TURNSTILE)
return o, M
# ===-----------------------------------------------------------------------===#
# Unit Tests
# ===-----------------------------------------------------------------------===#
@pytest.mark.parametrize("Z", [1, 4])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("N_CTX", [1024, 2048, 4096, 8192])
@pytest.mark.parametrize("HEAD_DIM", [64, 128])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_tmem_red", [False, True] if is_blackwell_ultra() else [False])
@pytest.mark.skipif(not is_blackwell(), reason="Gluon attention is only supported on Blackwell GPUs")
def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype, use_tmem_red, profile=False):
device = "cuda"
def alloc_fn(size: int, alignment: int, stream):
return torch.empty(size, dtype=torch.int8, device=device)
triton.set_allocator(alloc_fn)
if use_tmem_red and not is_blackwell_ultra():
pytest.skip("TMEM reduction is only supported on Blackwell Ultra GPUs")
torch.manual_seed(42)
q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=device).normal_(mean=0.0, std=0.5).requires_grad_())
sm_scale = 0.5
ref_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal)
tri_out, _ = attention_forward(q, k, v, causal, sm_scale, use_tmem_red=use_tmem_red)
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
# ===-----------------------------------------------------------------------===#
# Benchmarking
# ===-----------------------------------------------------------------------===#
BATCH = [4]
N_HEADS = [32]
HEAD_DIM = [64, 128]
causal = [False, True]
providers = ["triton-fp16", "triton-fp8"]
N_CTX = [2**i for i in range(10, 17)]
use_tmem_reds = [False, True] if is_blackwell_ultra() else [False]
bench_configs = []
for Z, H, D, is_causal, use_tmem_red in itertools.product(BATCH, N_HEADS, HEAD_DIM, causal, use_tmem_reds):
config = triton.testing.Benchmark(
x_names=["N_CTX"],
x_vals=N_CTX,
line_arg="provider",
line_vals=providers,
line_names=providers,
styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
ylabel="TFLOPS",
plot_name=f"Attention Z={Z} H={H} D={D} causal={is_causal} use_tmem_red={use_tmem_red}",
args={
"Z": Z,
"H": H,
"HEAD_DIM": D,
"causal": is_causal,
"use_tmem_red": use_tmem_red,
},
)
bench_configs.append(config)
@triton.testing.perf_report(bench_configs)
def bench(Z, H, N_CTX, HEAD_DIM, causal, use_tmem_red, provider):
provider, dtype = provider.split("-")
if dtype == "fp16":
dtype = torch.float16
elif dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp8":
dtype = torch.float8_e5m2
else:
raise ValueError(f"Unsupported dtype: {dtype}")
device = "cuda"
torch.manual_seed(42)
q = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=device).normal_(mean=0.0, std=0.5).requires_grad_()).to(dtype)
k = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=device).normal_(mean=0.0, std=0.5).requires_grad_()).to(dtype)
v = (torch.empty((Z, H, N_CTX, HEAD_DIM), device=device).normal_(mean=0.0, std=0.5).requires_grad_()).to(dtype)
sm_scale = 1.3
o = torch.empty_like(q)
M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.CUDNN_ATTENTION]):
if provider == "triton":
fn = lambda: attention_forward(q, k, v, causal, sm_scale, o, M, use_tmem_red=use_tmem_red)
elif provider == "cudnn":
fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, scale=sm_scale, is_causal=causal)
else:
raise ValueError(f"Unsupported provider: {provider}")
ms = triton.testing.do_bench_cudagraph(fn)
flops_per_matmul = 2.0 * Z * H * N_CTX * N_CTX * HEAD_DIM
total_flops = 2 * flops_per_matmul
if causal:
total_flops *= 0.5
return total_flops * 1e-12 / (ms * 1e-3)
if __name__ == "__main__":
bench.run(save_path=".", print_data=True)