Conv Dgrad
This example can be found at python/examples/gluon/02-conv-dgrad.py.
import importlib.util
import sys
from pathlib import Path
import pytest
import torch
import triton
import triton.language as tl
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor, TensorDescriptorIm2Col
from triton.experimental.gluon.language.nvidia.hopper import tma, mbarrier
from triton.experimental.gluon.language.nvidia.blackwell import (
TensorMemoryLayout,
allocate_tensor_memory,
tensor_memory_descriptor,
tcgen05_mma,
tcgen05_commit,
)
def _load_conv_common():
module_name = "triton_examples_gluon_conv_common"
module = sys.modules.get(module_name)
if module is not None:
return module
module_path = Path(__file__).with_name("02-conv-common.py")
spec = importlib.util.spec_from_file_location(module_name, module_path)
if spec is None or spec.loader is None:
raise ImportError(f"Unable to load shared conv helpers from {module_path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
_conv_common = _load_conv_common()
# ===-----------------------------------------------------------------------===#
# Utilities
# ===-----------------------------------------------------------------------===#
Counter = _conv_common.Counter
GL_GEMM_DTYPE = _conv_common.GL_GEMM_DTYPE
PersistentTileScheduler = _conv_common.PersistentTileScheduler
TORCH_GEMM_DTYPE = _conv_common.TORCH_GEMM_DTYPE
ensure_tma_compatible_strides = _conv_common.ensure_tma_compatible_strides
init_mbarrier_ring = _conv_common.init_mbarrier_ring
invalidate_mbarrier_ring = _conv_common.invalidate_mbarrier_ring
is_blackwell = _conv_common.is_blackwell
maybe_pad_channels_for_tma = _conv_common.maybe_pad_channel_dims_for_tma
normalize_2d = _conv_common.normalize_2d
# ===-----------------------------------------------------------------------===#
# Dgrad GEMM mapping
# ===-----------------------------------------------------------------------===#
#
# Dgrad as forward conv on grad_Y with rotated weight:
#   grad_X[M, Ci] = im2col(grad_Y)[M, R_eff*S_eff*Co] @ W_rot[Ci, R_eff*S_eff*Co]^T
#
# For stride > 1, the host decomposes dgrad into up to stride_h * stride_w
# subproblems. Each subproblem fixes (sub_a, sub_b, r0, s0, R_eff, S_eff,
# offset_a, offset_b), builds a grad_Y im2col descriptor, and launches the
# persistent kernel once.
# For each subproblem launch, the logical tile space is:
#     cdiv(M_GEMM, BLOCK_M) * cdiv(Ci, BLOCK_N) * active_split_k
# where M_GEMM = N * H_sub * W_sub,
#     total_k_iters = R_eff * S_eff * cdiv(Co, BLOCK_K), and
#     active_split_k = cdiv(total_k_iters, cdiv(total_k_iters, SPLIT_K)).
#
# The epilogue scatters results back to the full output tensor at
# h = sub_a + c_out_y * stride_h
# w = sub_b + c_out_x * stride_w
#
# If active_split_k > 1, the kernel stores fp32 partials to a workspace and a
# separate reduction kernel accumulates them into the final output. The launch
# uses a persistent scheduler and runs only `min(num_sms, logical_tiles)` CTAs.
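#
# Worked example (illustrative numbers): N = 128, H_sub = W_sub = 32,
# Ci = Co = 384, R_eff = S_eff = 2, BLOCK_M = 128, BLOCK_N = 256,
# BLOCK_K = 64, SPLIT_K = 2 gives
#     M_GEMM         = 128 * 32 * 32 = 131072
#     output tiles   = cdiv(131072, 128) * cdiv(384, 256) = 1024 * 2 = 2048
#     total_k_iters  = 2 * 2 * cdiv(384, 64) = 24
#     active_split_k = cdiv(24, cdiv(24, 2)) = 2
#     logical tiles  = 2048 * 2 = 4096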
@gluon.aggregate
class DgradConfig:
N: gl.tensor
Co: gl.tensor
Ci: gl.tensor
R_eff: gl.tensor
S_eff: gl.tensor
H_sub: gl.tensor
W_sub: gl.tensor
pad_h: gl.tensor
pad_w: gl.tensor
output_stride_n: gl.tensor
output_stride_h: gl.tensor
output_stride_w: gl.tensor
M_GEMM: gl.tensor
sub_a: gl.tensor
sub_b: gl.tensor
conv_stride_h: gl.tensor
conv_stride_w: gl.tensor
r0: gl.tensor
s0: gl.tensor
S_orig: gl.tensor
H_full: gl.tensor
W_full: gl.tensor
BLOCK_M: gl.constexpr
BLOCK_N: gl.constexpr
BLOCK_K: gl.constexpr
GROUP_SIZE_M: gl.constexpr
SPLIT_K: gl.constexpr
num_buffers: gl.constexpr
num_warps: gl.constexpr
@gluon.jit
def get_num_output_tiles(self):
return gl.cdiv(self.M_GEMM, self.BLOCK_M) * gl.cdiv(self.Ci, self.BLOCK_N)
@gluon.jit
def get_num_k_iterations(self):
return self.R_eff * self.S_eff * gl.cdiv(self.Co, self.BLOCK_K)
@gluon.jit
def get_active_split_k(self):
total_k_iters = self.get_num_k_iterations()
k_iters_per_split = gl.cdiv(total_k_iters, self.SPLIT_K)
return gl.cdiv(total_k_iters, k_iters_per_split)
@gluon.jit
def get_num_tiles(self):
return self.get_num_output_tiles() * self.get_active_split_k()
@gluon.jit
def get_program(self, pid):
active_split_k = self.get_active_split_k()
split_k_idx = pid % active_split_k
tile_id = pid // active_split_k
num_pid_m = gl.cdiv(self.M_GEMM, self.BLOCK_M)
num_pid_n = gl.cdiv(self.Ci, self.BLOCK_N)
num_pid_in_group = self.GROUP_SIZE_M * num_pid_n
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * self.GROUP_SIZE_M
group_size_m = gl.minimum(num_pid_m - first_pid_m, self.GROUP_SIZE_M)
pid_m = first_pid_m + (tile_id % group_size_m)
pid_n = (tile_id % num_pid_in_group) // group_size_m
total_k_iters = self.get_num_k_iterations()
k_iters_per_split = gl.cdiv(total_k_iters, active_split_k)
k_start = split_k_idx * k_iters_per_split
remaining_k_iters = total_k_iters - k_start
zero = gl.to_tensor(0)
k_iters_this_split = gl.where(
remaining_k_iters > 0,
gl.minimum(k_iters_per_split, remaining_k_iters),
zero,
)
return DgradProgram(self, pid_m, pid_n, split_k_idx, k_start, k_iters_this_split)
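# A plain-Python mirror of the grouped tile ordering above (illustration only;
# it is not used by the kernel). Mapping tile_id -> (pid_m, pid_n) groups
# GROUP_SIZE_M rows of output tiles together to improve L2 reuse, as in the
# classic Triton grouped-GEMM schedule.
def _host_grouped_tile(tile_id, num_pid_m, num_pid_n, GROUP_SIZE_M):
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (tile_id % group_size_m)
    pid_n = (tile_id % num_pid_in_group) // group_size_m
    return pid_m, pid_n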
@gluon.aggregate
class DgradProgram:
config: DgradConfig
pid_m: gl.tensor
pid_n: gl.tensor
split_k_idx: gl.tensor
k_start: gl.tensor
k_iters_this_split: gl.tensor
@gluon.jit
def get_m_offsets(self):
offs_m = self.pid_m * self.config.BLOCK_M
config = self.config
out_x = offs_m % config.W_sub
out_y = (offs_m // config.W_sub) % config.H_sub
batch_id = (offs_m // config.W_sub) // config.H_sub
return batch_id, out_y, out_x
@gluon.jit
def get_ci_offset(self):
return self.pid_n * self.config.BLOCK_N
@gluon.jit
def get_k_iteration(self, local_k):
k_iter = self.k_start + local_k
num_rs = self.config.R_eff * self.config.S_eff
iter_co = k_iter // num_rs
remain_rs = k_iter % num_rs
iter_s = remain_rs % self.config.S_eff
iter_r = remain_rs // self.config.S_eff
return iter_co, iter_r, iter_s
@gluon.jit
def get_weight_k_offset(self, local_k):
iter_co, iter_r, iter_s = self.get_k_iteration(local_k)
actual_r = self.config.r0 + iter_r * self.config.conv_stride_h
actual_s = self.config.s0 + iter_s * self.config.conv_stride_w
k_offset = (actual_r * self.config.S_orig + actual_s) * self.config.Co + iter_co * self.config.BLOCK_K
return iter_co, iter_r, iter_s, k_offset
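# Worked example (illustrative): R = S = 5, stride = 2, r0 = s0 = 1 gives
# R_eff = S_eff = 2 (weight taps r in {1, 3} and s in {1, 3}). With Co = 128
# and BLOCK_K = 64, k_iter = 6 decodes to (iter_co, iter_r, iter_s) = (1, 1, 0),
# i.e. tap (actual_r, actual_s) = (3, 1) and
# k_offset = (3 * 5 + 1) * 128 + 1 * 64 = 2112.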
@gluon.aggregate
class PartitionArgs:
config: DgradConfig
grad_y_desc: tma.tensor_descriptor_im2col
weight_desc: tma.tensor_descriptor
output_ptr: gl.tensor
store_split_k_partials: gl.constexpr
a_bufs: gl.shared_memory_descriptor
b_bufs: gl.shared_memory_descriptor
load_empty_bars: gl.shared_memory_descriptor
load_ready_bars: gl.shared_memory_descriptor
acc_bufs: tensor_memory_descriptor
acc_empty_bars: gl.shared_memory_descriptor
acc_ready_bars: gl.shared_memory_descriptor
# ===-----------------------------------------------------------------------===#
# Warp-Specialized Partitions
# ===-----------------------------------------------------------------------===#
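# The three partitions communicate through two rings of mbarriers:
# (load_empty_bars, load_ready_bars) hand A/B smem tiles from the load
# partition to the MMA partition, and (acc_empty_bars, acc_ready_bars) hand
# finished tensor-memory accumulators from the MMA partition to the epilogue.
# Every partition walks the same persistent tile schedule, so the producer and
# consumer sides of each ring stay in lock-step.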
@gluon.jit
def load_partition(p):
"""Load partition: iterate over the persistent dgrad work items assigned to this CTA."""
config = p.config
BLOCK_K: gl.constexpr = config.BLOCK_K
empty_bars = p.load_empty_bars
ready_bars = p.load_ready_bars
state = Counter.create(1, empty_bars.shape[0])
scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())
for idx in range(scheduler.get_num_tiles()):
prog = config.get_program(scheduler.get_tile_id(idx))
batch_id, out_y, out_x = prog.get_m_offsets()
ci_offset = prog.get_ci_offset()
for local_k in range(prog.k_iters_this_split):
iter_co, iter_r, iter_s, weight_k_offset = prog.get_weight_k_offset(local_k)
ready_bar = ready_bars.index(state.index)
mbarrier.wait(empty_bars.index(state.index), state.phase)
mbarrier.expect(ready_bar, p.grad_y_desc.block_type.nbytes + p.weight_desc.block_type.nbytes)
tma.async_load_im2col(
p.grad_y_desc,
[
batch_id,
out_y - config.pad_h,
out_x - config.pad_w,
iter_co * BLOCK_K,
],
[iter_r.to(tl.int16), iter_s.to(tl.int16)],
ready_bar,
p.a_bufs.index(state.index),
)
tma.async_load(
p.weight_desc,
[ci_offset, weight_k_offset],
ready_bar,
p.b_bufs.index(state.index),
)
state = state.next()
@gluon.jit
def mma_partition(p):
"""MMA partition: accumulate all split-K dgrad work items assigned to this CTA."""
config = p.config
load_state = Counter.create(0, p.load_empty_bars.shape[0])
acc_state = Counter.create(1, p.acc_empty_bars.shape[0])
scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())
for idx in range(scheduler.get_num_tiles()):
prog = config.get_program(scheduler.get_tile_id(idx))
mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase)
acc_buf = p.acc_bufs.index(acc_state.index)
use_acc = False
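        # The first MMA of each output tile overwrites the accumulator
        # (use_acc=False); subsequent k-iterations accumulate into it.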
for _local_k in range(prog.k_iters_this_split):
mbarrier.wait(p.load_ready_bars.index(load_state.index), load_state.phase)
tcgen05_mma(
p.a_bufs.index(load_state.index),
p.b_bufs.index(load_state.index).permute((1, 0)),
acc_buf,
use_acc=use_acc,
)
tcgen05_commit(p.load_empty_bars.index(load_state.index))
load_state = load_state.next()
use_acc = True
tcgen05_commit(p.acc_ready_bars.index(acc_state.index))
acc_state = acc_state.next()
@gluon.jit
def epilogue_partition(p):
"""Epilogue partition: store the persistent dgrad work items assigned to this CTA."""
config = p.config
BLOCK_M: gl.constexpr = config.BLOCK_M
BLOCK_N: gl.constexpr = config.BLOCK_N
M_GEMM = config.M_GEMM
N_GEMM = config.Ci
acc_state = Counter.create(0, p.acc_empty_bars.shape[0])
scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())
for idx in range(scheduler.get_num_tiles()):
prog = config.get_program(scheduler.get_tile_id(idx))
mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
acc = p.acc_bufs.index(acc_state.index).load()
mbarrier.arrive(p.acc_empty_bars.index(acc_state.index), count=1)
acc_state = acc_state.next()
offs_m = prog.pid_m * BLOCK_M + gl.arange(0, BLOCK_M)
offs_n = prog.get_ci_offset() + gl.arange(0, BLOCK_N)
c_out_x = offs_m % config.W_sub
c_out_y = (offs_m // config.W_sub) % config.H_sub
c_batch = (offs_m // config.W_sub) // config.H_sub
h = config.sub_a + c_out_y * config.conv_stride_h
w = config.sub_b + c_out_x * config.conv_stride_w
c_offsets = (c_batch[:, None] * config.output_stride_n + h[:, None] * config.output_stride_h +
w[:, None] * config.output_stride_w + offs_n[None, :])
c_mask = ((offs_m[:, None] < M_GEMM) & (offs_n[None, :] < N_GEMM) & (h[:, None] < config.H_full) &
(w[:, None] < config.W_full))
result = gl.convert_layout(acc, gl.CoalescedLayout())
if p.store_split_k_partials:
split_batch = prog.split_k_idx * config.N + c_batch
split_offsets = (split_batch[:, None] * config.output_stride_n + h[:, None] * config.output_stride_h +
w[:, None] * config.output_stride_w + offs_n[None, :])
gl.store(p.output_ptr + split_offsets, result, mask=c_mask)
else:
gl.store(p.output_ptr + c_offsets, result.to(GL_GEMM_DTYPE), mask=c_mask)
invalidate_mbarrier_ring(p.load_empty_bars)
invalidate_mbarrier_ring(p.load_ready_bars)
invalidate_mbarrier_ring(p.acc_empty_bars)
invalidate_mbarrier_ring(p.acc_ready_bars)
# ===-----------------------------------------------------------------------===#
# Kernel Entry Point
# ===-----------------------------------------------------------------------===#
@gluon.jit(do_not_specialize=[
"N",
"S_orig",
"H_sub",
"W_sub",
"H_full",
"W_full",
"conv_stride_h",
"conv_stride_w",
"sub_a",
"sub_b",
"r0",
"s0",
"R_eff",
"S_eff",
"pad_h",
"pad_w",
])
def conv2d_dgrad_kernel(
grad_y_desc,
weight_desc,
output,
N,
Co,
Ci,
S_orig,
H_sub,
W_sub,
H_full,
W_full,
output_stride_n,
output_stride_h,
output_stride_w,
conv_stride_h,
conv_stride_w,
sub_a,
sub_b,
r0,
s0,
R_eff,
S_eff,
pad_h,
pad_w,
STORE_SPLIT_K_PARTIALS: gl.constexpr,
BLOCK_M: gl.constexpr,
BLOCK_N: gl.constexpr,
BLOCK_K: gl.constexpr,
GROUP_SIZE_M: gl.constexpr,
SPLIT_K: gl.constexpr,
num_buffers: gl.constexpr,
num_acc_buffers: gl.constexpr,
num_warps: gl.constexpr,
):
"""Warp-specialized dgrad kernel.
    Logical tile space = cdiv(M_GEMM, BLOCK_M) * cdiv(Ci, BLOCK_N), optionally
    multiplied by the active split-K factor. Sub-problem parameters (sub_a/b,
    r0/s0, R_eff/S_eff, pad) are per-launch constants.
"""
M_GEMM = N * H_sub * W_sub
config = DgradConfig(
N,
Co,
Ci,
R_eff,
S_eff,
gl.to_tensor(H_sub),
gl.to_tensor(W_sub),
pad_h,
pad_w,
gl.to_tensor(output_stride_n),
gl.to_tensor(output_stride_h),
gl.to_tensor(output_stride_w),
M_GEMM,
sub_a,
sub_b,
gl.to_tensor(conv_stride_h),
gl.to_tensor(conv_stride_w),
r0,
s0,
S_orig,
H_full,
W_full,
BLOCK_M,
BLOCK_N,
BLOCK_K,
GROUP_SIZE_M,
SPLIT_K,
num_buffers,
num_warps,
)
a_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], GL_GEMM_DTYPE)
b_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK_N, BLOCK_K], GL_GEMM_DTYPE)
a_bufs = gl.allocate_shared_memory(GL_GEMM_DTYPE, [num_buffers, BLOCK_M, BLOCK_K], a_smem_layout)
b_bufs = gl.allocate_shared_memory(GL_GEMM_DTYPE, [num_buffers, BLOCK_N, BLOCK_K], b_smem_layout)
load_empty_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
load_ready_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
init_mbarrier_ring(load_empty_bars)
init_mbarrier_ring(load_ready_bars)
TMEM_BLOCK_M: gl.constexpr = 64 if BLOCK_M == 64 else 128
tmem_layout: gl.constexpr = TensorMemoryLayout(block=(TMEM_BLOCK_M, BLOCK_N), col_stride=1)
acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout)
acc_empty_bars = gl.allocate_shared_memory(gl.int64, [num_acc_buffers, 1], mbarrier.MBarrierLayout())
acc_ready_bars = gl.allocate_shared_memory(gl.int64, [num_acc_buffers, 1], mbarrier.MBarrierLayout())
init_mbarrier_ring(acc_empty_bars)
init_mbarrier_ring(acc_ready_bars)
p = PartitionArgs(
config,
grad_y_desc,
weight_desc,
output,
STORE_SPLIT_K_PARTIALS,
a_bufs,
b_bufs,
load_empty_bars,
load_ready_bars,
acc_bufs,
acc_empty_bars,
acc_ready_bars,
)
gl.warp_specialize([
(epilogue_partition, (p, )),
(mma_partition, (p, )),
(load_partition, (p, )),
], [1, 1], [24, 24])
# ===-----------------------------------------------------------------------===#
# Autotuning
# ===-----------------------------------------------------------------------===#
def conv2d_dgrad_get_configs():
return [
triton.Config(
{
"BLOCK_M": block_m,
"BLOCK_N": block_n,
"BLOCK_K": block_k,
"GROUP_SIZE_M": group_size_m,
"SPLIT_K": split_k,
"num_buffers": num_buffers,
"num_acc_buffers": num_acc_buffers,
},
num_warps=num_warps,
)
for block_m in (64, 128)
for block_n in (64, 128, 256)
for block_k in (64, )
for group_size_m in (4, )
for split_k in (1, 2, 4, 8)
for num_buffers in (3, 4, 5)
for num_acc_buffers in (2, )
for num_warps in (4, )
]
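# The cross product yields 2 (BLOCK_M) * 3 (BLOCK_N) * 4 (SPLIT_K) *
# 3 (num_buffers) = 72 candidate configs; the host-side autotuner below
# benchmarks each one per problem shape and caches the winner.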
# ===-----------------------------------------------------------------------===#
# Host-Side Entry Point
# ===-----------------------------------------------------------------------===#
def _make_dgrad_subproblem_specs(R, S, stride_h, stride_w, pad_h, pad_w):
p_h_prime = R - 1 - pad_h
p_w_prime = S - 1 - pad_w
subproblem_specs = []
for a in range(stride_h):
for b in range(stride_w):
r0 = ((p_h_prime - a) % stride_h + stride_h) % stride_h
s0 = ((p_w_prime - b) % stride_w + stride_w) % stride_w
R_eff = (R - r0 + stride_h - 1) // stride_h
S_eff = (S - s0 + stride_w - 1) // stride_w
if R_eff <= 0 or S_eff <= 0:
continue
offset_a = (a + r0 - p_h_prime) // stride_h
offset_b = (b + s0 - p_w_prime) // stride_w
subproblem_specs.append((a, b, r0, s0, R_eff, S_eff, offset_a, offset_b))
return subproblem_specs
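# Worked example (illustrative): R = S = 3, stride = (2, 2), padding = (1, 1)
# gives p_h_prime = p_w_prime = 1 and four sub-problems
#     (a, b, r0, s0, R_eff, S_eff, offset_a, offset_b):
#     (0, 0, 1, 1, 2, 2, 0, 0)
#     (0, 1, 1, 0, 2, 2, 0, 0)
#     (1, 0, 0, 1, 2, 2, 0, 0)
#     (1, 1, 0, 0, 2, 2, 0, 0)
# Together the four launches cover every input pixel exactly once.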
def _prepare_dgrad_inputs(grad_output_nhwc, weight_nhwc, H_in, W_in, stride, padding):
"""Validate inputs, pad channels, and compute sub-problem decomposition."""
if grad_output_nhwc.dtype != TORCH_GEMM_DTYPE or weight_nhwc.dtype != TORCH_GEMM_DTYPE:
        raise ValueError(f"conv2d_dgrad expects {TORCH_GEMM_DTYPE} grad-output and weight tensors, "
                         f"got {grad_output_nhwc.dtype} and {weight_nhwc.dtype}")
stride_h, stride_w = normalize_2d(stride, "stride")
pad_h, pad_w = normalize_2d(padding, "padding")
if stride_h <= 0 or stride_w <= 0:
raise ValueError(f"stride must be positive, got {(stride_h, stride_w)}")
if pad_h < 0 or pad_w < 0:
raise ValueError(f"padding must be non-negative, got {(pad_h, pad_w)}")
N, out_h, out_w, Co = grad_output_nhwc.shape
Co_w, R, S, Ci = weight_nhwc.shape
if Co != Co_w:
raise ValueError(f"Channel dimension mismatch: grad-output has {Co}, weight has {Co_w}")
expected_out_h = (H_in + 2 * pad_h - R) // stride_h + 1
expected_out_w = (W_in + 2 * pad_w - S) // stride_w + 1
if out_h != expected_out_h or out_w != expected_out_w:
raise ValueError("Grad-output shape mismatch: expected "
f"({N}, {expected_out_h}, {expected_out_w}, {Co}) from input/filter geometry, got "
f"({N}, {out_h}, {out_w}, {Co}).")
if out_h <= 0 or out_w <= 0:
raise ValueError("Invalid convolution geometry for dgrad")
    grad_output_nhwc = maybe_pad_channels_for_tma(grad_output_nhwc)
grad_output_nhwc = ensure_tma_compatible_strides(grad_output_nhwc)
Co_padded = grad_output_nhwc.shape[-1]
if Co_padded != Co:
w_padded = weight_nhwc.new_zeros((Co_padded, R, S, Ci))
w_padded[:Co] = weight_nhwc
weight_nhwc = w_padded.contiguous()
Co = Co_padded
W_rot = weight_nhwc.flip(1, 2).permute(3, 1, 2, 0).contiguous() # (Ci, R, S, Co)
W_rot_flat = W_rot.reshape(Ci, R * S * Co)
H_sub = (H_in + stride_h - 1) // stride_h
W_sub = (W_in + stride_w - 1) // stride_w
subproblem_specs = _make_dgrad_subproblem_specs(R, S, stride_h, stride_w, pad_h, pad_w)
if not subproblem_specs:
raise ValueError("No valid dgrad sub-problems were generated")
return (
grad_output_nhwc,
W_rot_flat,
N,
Co,
Ci,
S,
out_h,
out_w,
H_in,
W_in,
H_sub,
W_sub,
stride_h,
stride_w,
subproblem_specs,
)
def _make_dgrad_weight_descriptor(W_rot_flat, weight_block_shape):
weight_layout = gl.NVMMASharedLayout.get_default_for(weight_block_shape, GL_GEMM_DTYPE)
weight_desc = TensorDescriptor.from_tensor(W_rot_flat, weight_block_shape, weight_layout)
return weight_desc
def _make_dgrad_grad_y_descriptor(grad_output_nhwc, H_sub, W_sub, out_h, out_w, offset_a, offset_b, input_block_shape):
lower_h = offset_a
lower_w = offset_b
upper_h = H_sub + offset_a - out_h
upper_w = W_sub + offset_b - out_w
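    # The pixel box defines the valid gather window in grad_Y for this
    # sub-problem: the lower corners shift the origin by (offset_a, offset_b),
    # and the upper corners trim the window so an H_sub x W_sub logical image
    # never reads past (out_h, out_w); out-of-window pixels load as zeros.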
input_layout = gl.NVMMASharedLayout.get_default_for(input_block_shape, GL_GEMM_DTYPE)
return TensorDescriptorIm2Col(
base=grad_output_nhwc,
shape=list(grad_output_nhwc.shape),
strides=list(grad_output_nhwc.stride()),
block_shape=input_block_shape,
layout=input_layout,
padding="zero",
element_strides=[1, 1, 1, 1],
pixel_box_lower_corner=[lower_h, lower_w],
pixel_box_upper_corner=[upper_h, upper_w],
)
def _make_grid(num_sms, M_GEMM, Ci, Co, R_eff, S_eff):
def grid(meta):
total_mn_tiles = triton.cdiv(M_GEMM, meta["BLOCK_M"]) * triton.cdiv(Ci, meta["BLOCK_N"])
total_k_iters = R_eff * S_eff * triton.cdiv(Co, meta["BLOCK_K"])
k_iters_per_split = triton.cdiv(total_k_iters, meta["SPLIT_K"])
active_split_k = triton.cdiv(total_k_iters, k_iters_per_split)
total_tiles = total_mn_tiles * active_split_k
return (min(num_sms, total_tiles), )
return grid
def _get_active_split_k(total_k_iters, SPLIT_K):
k_iters_per_split = triton.cdiv(total_k_iters, SPLIT_K)
return triton.cdiv(total_k_iters, k_iters_per_split)
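# Example (illustrative): total_k_iters = 9, SPLIT_K = 4 gives
# k_iters_per_split = cdiv(9, 4) = 3 and active_split_k = cdiv(9, 3) = 3,
# so only three of the four requested splits receive any work.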
def _get_dgrad_subproblem_active_split_k(Co, R_eff, S_eff, BLOCK_K, SPLIT_K):
total_k_iters = R_eff * S_eff * triton.cdiv(Co, BLOCK_K)
return _get_active_split_k(total_k_iters, SPLIT_K)
def _get_max_active_split_k(Co, subproblem_specs, BLOCK_K, SPLIT_K):
return max(
_get_dgrad_subproblem_active_split_k(Co, R_eff, S_eff, BLOCK_K, SPLIT_K)
for _, _, _, _, R_eff, S_eff, _, _ in subproblem_specs)
def _get_safe_dgrad_max_active_split_k(Co, subproblem_specs, N, H_in, W_in, Ci, kernel_meta):
"""Return max active split-K across subproblems, or raise if workspace would be too large to index safely."""
max_active_split_k = _get_max_active_split_k(Co, subproblem_specs, kernel_meta["BLOCK_K"], kernel_meta["SPLIT_K"])
if max_active_split_k > 1:
# Workspace shape: (active_split_k * N, H_in, W_in, Ci); indexed in kernels by batch/row offsets.
# Very large workspaces can exceed the addressing range supported by the generated code.
workspace_elems = max_active_split_k * N * H_in * W_in * Ci
if workspace_elems > (2**31 - 1):
raise ValueError("dgrad split-K workspace exceeds safe indexing range: "
f"active_split_k={max_active_split_k}, N={N}, H_in={H_in}, W_in={W_in}, Ci={Ci}")
return max_active_split_k
def _allocate_dgrad_split_k_workspace(device, active_split_k, N, H_in, W_in, Ci):
return torch.zeros((active_split_k * N, H_in, W_in, Ci), device=device, dtype=torch.float32)
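# Partials for split index s and batch n live at workspace batch row s * N + n,
# matching the `split_batch` addressing used by the epilogue partition and the
# reduction kernel below.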
_dgrad_autotune_cache = {}
def _make_dgrad_autotune_key(
device,
num_sms,
N,
Co,
Ci,
S,
out_h,
out_w,
H_in,
W_in,
H_sub,
W_sub,
stride_h,
stride_w,
subproblem_specs,
):
return (
torch.cuda.get_device_capability(device),
num_sms,
N,
Co,
Ci,
S,
out_h,
out_w,
H_in,
W_in,
H_sub,
W_sub,
stride_h,
stride_w,
tuple(subproblem_specs),
)
@triton.jit
def reduce_dgrad_split_k_partials_kernel(
partial_ptr,
output_ptr,
partial_stride_n,
partial_stride_h,
partial_stride_w,
output_stride_n,
output_stride_h,
output_stride_w,
N,
H,
W,
Ci,
ACTIVE_SPLIT_K: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
M = N * H * W
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
w_idx = offs_m % W
h_idx = (offs_m // W) % H
batch_idx = offs_m // (H * W)
mask = (offs_m[:, None] < M) & (offs_n[None, :] < Ci)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for split_k_idx in range(ACTIVE_SPLIT_K):
partial_batch = split_k_idx * N + batch_idx
partial_offsets = (partial_batch[:, None] * partial_stride_n + h_idx[:, None] * partial_stride_h +
w_idx[:, None] * partial_stride_w + offs_n[None, :])
acc += tl.load(partial_ptr + partial_offsets, mask=mask, other=0.0)
output_offsets = (batch_idx[:, None] * output_stride_n + h_idx[:, None] * output_stride_h +
w_idx[:, None] * output_stride_w + offs_n[None, :])
tl.store(output_ptr + output_offsets, acc, mask=mask)
def _reduce_dgrad_split_k_partials(partials, output, N, H, W, Ci, active_split_k):
BLOCK_M = 64
BLOCK_N = 64
grid = (triton.cdiv(N * H * W, BLOCK_M), triton.cdiv(Ci, BLOCK_N))
reduce_dgrad_split_k_partials_kernel[grid](
partials,
output,
partials.stride(0),
partials.stride(1),
partials.stride(2),
output.stride(0),
output.stride(1),
output.stride(2),
N,
H,
W,
Ci,
ACTIVE_SPLIT_K=active_split_k,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=4,
)
def _launch_dgrad_subproblems(
kernel,
num_sms,
*,
grad_output_nhwc,
W_rot_flat,
output,
N,
Co,
Ci,
S,
out_h,
out_w,
H_in,
W_in,
H_sub,
W_sub,
stride_h,
stride_w,
subproblem_specs,
weight_block_shape,
input_block_shape,
kernel_meta=None,
):
if kernel_meta is None:
kernel_meta = {}
kernel_meta.setdefault("STORE_SPLIT_K_PARTIALS", False)
M_GEMM = N * H_sub * W_sub
weight_desc = _make_dgrad_weight_descriptor(W_rot_flat, weight_block_shape)
for a, b, r0_val, s0_val, R_eff_val, S_eff_val, offset_a, offset_b in subproblem_specs:
grad_y_desc = _make_dgrad_grad_y_descriptor(
grad_output_nhwc,
H_sub,
W_sub,
out_h,
out_w,
offset_a,
offset_b,
input_block_shape,
)
kernel[_make_grid(num_sms, M_GEMM, Ci, Co, R_eff_val, S_eff_val)](
grad_y_desc=grad_y_desc,
weight_desc=weight_desc,
output=output,
N=N,
Co=Co,
Ci=Ci,
S_orig=S,
H_sub=H_sub,
W_sub=W_sub,
H_full=H_in,
W_full=W_in,
output_stride_n=output.stride(0),
output_stride_h=output.stride(1),
output_stride_w=output.stride(2),
conv_stride_h=stride_h,
conv_stride_w=stride_w,
sub_a=a,
sub_b=b,
r0=r0_val,
s0=s0_val,
R_eff=R_eff_val,
S_eff=S_eff_val,
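        # pad is the negated sub-problem offset, so the loader's
        # (out_y - pad_h, out_x - pad_w) equals (out_y + offset_a, out_x + offset_b).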
pad_h=-offset_a,
pad_w=-offset_b,
**kernel_meta,
)
def _allocate_dgrad_output(device, N, H_in, W_in, Ci, split_k=1):
if split_k == 1:
return torch.empty((N, H_in, W_in, Ci), device=device, dtype=TORCH_GEMM_DTYPE)
return torch.zeros((N, H_in, W_in, Ci), device=device, dtype=torch.float32)
def _finalize_dgrad_output(output):
return output.to(TORCH_GEMM_DTYPE)
def _make_dgrad_runner(
grad_output_nhwc,
W_rot_flat,
*,
N,
Co,
Ci,
S,
out_h,
out_w,
H_in,
W_in,
H_sub,
W_sub,
stride_h,
stride_w,
subproblem_specs,
num_sms,
kernel_meta,
):
max_active_split_k = _get_safe_dgrad_max_active_split_k(Co, subproblem_specs, N, H_in, W_in, Ci, kernel_meta)
uses_split_k_workspace = max_active_split_k > 1
output = _allocate_dgrad_output(
grad_output_nhwc.device,
N,
H_in,
W_in,
Ci,
split_k=max_active_split_k if uses_split_k_workspace else 1,
)
launch_output = output
if uses_split_k_workspace:
launch_output = _allocate_dgrad_split_k_workspace(grad_output_nhwc.device, max_active_split_k, N, H_in, W_in,
Ci)
def run():
_launch_dgrad_subproblems(
conv2d_dgrad_kernel,
num_sms,
grad_output_nhwc=grad_output_nhwc,
W_rot_flat=W_rot_flat,
output=launch_output,
N=N,
Co=Co,
Ci=Ci,
S=S,
out_h=out_h,
out_w=out_w,
H_in=H_in,
W_in=W_in,
H_sub=H_sub,
W_sub=W_sub,
stride_h=stride_h,
stride_w=stride_w,
subproblem_specs=subproblem_specs,
weight_block_shape=[kernel_meta["BLOCK_N"], kernel_meta["BLOCK_K"]],
input_block_shape=[kernel_meta["BLOCK_M"], kernel_meta["BLOCK_K"]],
kernel_meta={
**kernel_meta,
"STORE_SPLIT_K_PARTIALS": uses_split_k_workspace,
},
)
if uses_split_k_workspace:
_reduce_dgrad_split_k_partials(launch_output, output, N, H_in, W_in, Ci, max_active_split_k)
return run, output
def _benchmark_dgrad_config(
grad_output_nhwc,
W_rot_flat,
*,
N,
Co,
Ci,
S,
out_h,
out_w,
H_in,
W_in,
H_sub,
W_sub,
stride_h,
stride_w,
subproblem_specs,
num_sms,
kernel_meta,
):
try:
run, _ = _make_dgrad_runner(
grad_output_nhwc,
W_rot_flat,
N=N,
Co=Co,
Ci=Ci,
S=S,
out_h=out_h,
out_w=out_w,
H_in=H_in,
W_in=W_in,
H_sub=H_sub,
W_sub=W_sub,
stride_h=stride_h,
stride_w=stride_w,
subproblem_specs=subproblem_specs,
num_sms=num_sms,
kernel_meta=kernel_meta,
)
run()
torch.cuda.synchronize()
return triton.testing.do_bench(run)
except Exception:
return float("inf")
def _select_dgrad_kernel_meta(
grad_output_nhwc,
W_rot_flat,
*,
N,
Co,
Ci,
S,
out_h,
out_w,
H_in,
W_in,
H_sub,
W_sub,
stride_h,
stride_w,
subproblem_specs,
num_sms,
):
cache_key = _make_dgrad_autotune_key(
grad_output_nhwc.device,
num_sms,
N,
Co,
Ci,
S,
out_h,
out_w,
H_in,
W_in,
H_sub,
W_sub,
stride_h,
stride_w,
subproblem_specs,
)
cached = _dgrad_autotune_cache.get(cache_key)
if cached is not None:
return dict(cached)
best_ms = float("inf")
best_kernel_meta = None
for config in conv2d_dgrad_get_configs():
kernel_meta = config.all_kwargs()
ms = _benchmark_dgrad_config(
grad_output_nhwc,
W_rot_flat,
N=N,
Co=Co,
Ci=Ci,
S=S,
out_h=out_h,
out_w=out_w,
H_in=H_in,
W_in=W_in,
H_sub=H_sub,
W_sub=W_sub,
stride_h=stride_h,
stride_w=stride_w,
subproblem_specs=subproblem_specs,
num_sms=num_sms,
kernel_meta=kernel_meta,
)
if ms < best_ms:
best_ms = ms
best_kernel_meta = dict(kernel_meta)
if best_kernel_meta is None:
raise RuntimeError("Failed to autotune conv2d_dgrad: no valid kernel configurations.")
_dgrad_autotune_cache[cache_key] = dict(best_kernel_meta)
return dict(best_kernel_meta)
def conv2d_dgrad(grad_output_nhwc, weight_nhwc, H_in, W_in, stride=1, padding=0):
"""Production dgrad entrypoint.
Selects the best kernel configuration with host-side autotuning, then runs
deterministic two-pass split-K when reduction is needed.
"""
(grad_output_nhwc, W_rot_flat, N, Co, Ci, S,
out_h, out_w, H_in, W_in, H_sub, W_sub,
stride_h, stride_w, subproblem_specs) = \
_prepare_dgrad_inputs(grad_output_nhwc, weight_nhwc, H_in, W_in, stride, padding)
device = grad_output_nhwc.device
num_sms = torch.cuda.get_device_properties(device).multi_processor_count
kernel_meta = _select_dgrad_kernel_meta(
grad_output_nhwc,
W_rot_flat,
N=N,
Co=Co,
Ci=Ci,
S=S,
out_h=out_h,
out_w=out_w,
H_in=H_in,
W_in=W_in,
H_sub=H_sub,
W_sub=W_sub,
stride_h=stride_h,
stride_w=stride_w,
subproblem_specs=subproblem_specs,
num_sms=num_sms,
)
run, output = _make_dgrad_runner(
grad_output_nhwc,
W_rot_flat,
N=N,
Co=Co,
Ci=Ci,
S=S,
out_h=out_h,
out_w=out_w,
H_in=H_in,
W_in=W_in,
H_sub=H_sub,
W_sub=W_sub,
stride_h=stride_h,
stride_w=stride_w,
subproblem_specs=subproblem_specs,
num_sms=num_sms,
kernel_meta=kernel_meta,
)
run()
return _finalize_dgrad_output(output)
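# Example usage (hypothetical shapes; requires a Blackwell GPU):
#     grad_y = torch.randn(8, 32, 32, 128, device="cuda", dtype=torch.bfloat16)
#     weight = torch.randn(128, 3, 3, 64, device="cuda", dtype=torch.bfloat16)  # (Co, R, S, Ci)
#     grad_x = conv2d_dgrad(grad_y, weight, H_in=64, W_in=64, stride=2, padding=1)
#     assert grad_x.shape == (8, 64, 64, 64)  # NHWC grad w.r.t. the input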
def _make_dgrad_fixed_kernel_meta(SPLIT_K, num_buffers, num_warps):
# Keep the fixed path on a tile shape that is also covered by autotune configs.
return {
"BLOCK_M": 128,
"BLOCK_N": 256,
"BLOCK_K": 64,
"GROUP_SIZE_M": 4,
"SPLIT_K": SPLIT_K,
"num_buffers": num_buffers,
"num_acc_buffers": 2,
"num_warps": num_warps,
}
def conv2d_dgrad_fixed(grad_output_nhwc, weight_nhwc, H_in, W_in, stride=1, padding=0, num_buffers=2, num_warps=4,
SPLIT_K=1):
"""Fixed-config dgrad entrypoint used for CI and debugging.
Runs the kernel with a fixed supported tile shape instead of autotuning,
while still using deterministic two-pass split-K when reduction is needed.
"""
(grad_output_nhwc, W_rot_flat, N, Co, Ci, S,
out_h, out_w, H_in, W_in, H_sub, W_sub,
stride_h, stride_w, subproblem_specs) = \
_prepare_dgrad_inputs(grad_output_nhwc, weight_nhwc, H_in, W_in, stride, padding)
device = grad_output_nhwc.device
num_sms = torch.cuda.get_device_properties(device).multi_processor_count
kernel_meta = _make_dgrad_fixed_kernel_meta(SPLIT_K, num_buffers, num_warps)
run, output = _make_dgrad_runner(
grad_output_nhwc,
W_rot_flat,
N=N,
Co=Co,
Ci=Ci,
S=S,
out_h=out_h,
out_w=out_w,
H_in=H_in,
W_in=W_in,
H_sub=H_sub,
W_sub=W_sub,
stride_h=stride_h,
stride_w=stride_w,
subproblem_specs=subproblem_specs,
num_sms=num_sms,
kernel_meta=kernel_meta,
)
run()
return _finalize_dgrad_output(output)
# ===-----------------------------------------------------------------------===#
# Unit Tests
# ===-----------------------------------------------------------------------===#
def _assert_dgrad_correct(dgrad_fn, N, Ci, H, W, Co, R, S, stride, padding, **kwargs):
"""Run dgrad and compare against PyTorch autograd reference."""
torch.manual_seed(0)
stride_h, stride_w = normalize_2d(stride, "stride")
pad_h, pad_w = normalize_2d(padding, "padding")
out_h = (H + 2 * pad_h - R) // stride_h + 1
out_w = (W + 2 * pad_w - S) // stride_w + 1
grad_out_nchw = torch.randn((N, Co, out_h, out_w), device="cuda", dtype=TORCH_GEMM_DTYPE)
grad_out_nhwc = grad_out_nchw.permute(0, 2, 3, 1).contiguous()
w_nchw = torch.randn((Co, Ci, R, S), device="cuda", dtype=TORCH_GEMM_DTYPE)
w_nhwc = w_nchw.permute(0, 2, 3, 1).contiguous()
triton_dgrad = dgrad_fn(grad_out_nhwc, w_nhwc, H, W, stride=stride, padding=padding, **kwargs)
ref_dgrad = torch.ops.aten.convolution_backward(
grad_out_nchw,
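        # grad_input is independent of the input values; a random placeholder
        # tensor supplies only the shape metadata.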
torch.randn((N, Ci, H, W), device="cuda", dtype=TORCH_GEMM_DTYPE),
w_nchw,
bias_sizes=None,
stride=[stride_h, stride_w],
padding=[pad_h, pad_w],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
output_mask=[True, False, False],
)[0]
ref_dgrad_nhwc = ref_dgrad.permute(0, 2, 3, 1).contiguous()
torch.testing.assert_close(triton_dgrad, ref_dgrad_nhwc, atol=1e-2, rtol=1e-2)
DGRAD_SHAPE_PARAMS = [
*[(N, Ci, 64, 64, Co, R, S, stride, padding)
for N in (1, 128)
for Ci, Co in ((384, 384), (128, 128))
for R, S in ((2, 2), (3, 3))
for stride in (1, 2)
for padding in (0, 1)],
(16, 5, 32, 32, 96, 3, 3, 1, 1),
(16, 512, 2, 2, 768, 2, 2, 2, 0),
(16, 96, 1, 8, 128, 1, 2, (1, 2), (0, 0)),
(16, 128, 1, 4, 192, 1, 2, (1, 2), (0, 0)),
(16, 160, 1, 2, 256, 1, 2, (1, 2), (0, 0)),
]
@pytest.mark.parametrize(
"dgrad_fn,N,Ci,H,W,Co,R,S,stride,padding",
[(conv2d_dgrad_fixed, *shape) for shape in DGRAD_SHAPE_PARAMS],
)
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell GPU (SM 10.x)")
def test_op(dgrad_fn, N, Ci, H, W, Co, R, S, stride, padding):
_assert_dgrad_correct(dgrad_fn, N, Ci, H, W, Co, R, S, stride, padding)
# ===-----------------------------------------------------------------------===#
# Benchmarking
# ===-----------------------------------------------------------------------===#
BATCH = [128]
CHANNELS = [(384, 384)]
SPATIAL = [(64, 64)]
FILTER = [(3, 3)]
STRIDE = [1]
PADDING = [1]
def _make_bench_inputs(N, H, W, Ci, Co, R, S, stride_val, pad_val):
torch.manual_seed(0)
out_h = (H + 2 * pad_val - R) // stride_val + 1
out_w = (W + 2 * pad_val - S) // stride_val + 1
x_nchw = torch.randn((N, Ci, H, W), device="cuda", dtype=TORCH_GEMM_DTYPE)
grad_out_nchw = torch.randn((N, Co, out_h, out_w), device="cuda", dtype=TORCH_GEMM_DTYPE)
grad_out_nhwc = grad_out_nchw.permute(0, 2, 3, 1).contiguous()
w_nchw = torch.randn((Co, Ci, R, S), device="cuda", dtype=TORCH_GEMM_DTYPE)
w_nhwc = w_nchw.permute(0, 2, 3, 1).contiguous()
return x_nchw, grad_out_nchw, grad_out_nhwc, w_nchw, w_nhwc
def _benchmark_tflops(fn, *, N, H, W, Ci, Co, R, S, stride_val, pad_val):
ms = triton.testing.do_bench(fn)
out_h = (H + 2 * pad_val - R) // stride_val + 1
out_w = (W + 2 * pad_val - S) // stride_val + 1
flops = 2.0 * N * out_h * out_w * Co * Ci * R * S
return flops * 1e-12 / (ms * 1e-3)
bench_configs = []
_bench_cases = [(N, ch, sp, f, s, p)
                for N in BATCH
                for ch in CHANNELS
                for sp in SPATIAL
                for f in FILTER
                for s in STRIDE
                for p in PADDING]
for N, (Ci, Co), (H, W), (R, S), stride_val, pad_val in _bench_cases:
bench_configs.append(
triton.testing.Benchmark(
x_names=["kernel"],
x_vals=["autotuned"],
line_arg="provider",
line_vals=["gluon", "torch"],
line_names=["Gluon (autotuned)", "PyTorch"],
styles=[("green", "-"), ("blue", "-")],
ylabel="TFLOPS",
plot_name=f"Dgrad N={N} Ci={Ci} Co={Co} H={H} W={W} R={R} S={S} stride={stride_val} pad={pad_val}",
args={
"N": N,
"H": H,
"W": W,
"Ci": Ci,
"Co": Co,
"R": R,
"S": S,
"stride_val": stride_val,
"pad_val": pad_val,
},
))
@triton.testing.perf_report(bench_configs)
def bench(N, H, W, Ci, Co, R, S, stride_val, pad_val, kernel, provider):
x_nchw, grad_out_nchw, grad_out_nhwc, w_nchw, w_nhwc = \
_make_bench_inputs(N, H, W, Ci, Co, R, S, stride_val, pad_val)
if provider == "gluon":
fn = lambda: conv2d_dgrad(grad_out_nhwc, w_nhwc, H, W, stride=stride_val, padding=pad_val)
elif provider == "torch":
fn = lambda: torch.ops.aten.convolution_backward(
grad_out_nchw,
x_nchw,
w_nchw,
bias_sizes=None,
stride=[stride_val, stride_val],
padding=[pad_val, pad_val],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=1,
output_mask=[True, False, False],
)
else:
raise ValueError(f"Unsupported provider: {provider}")
return _benchmark_tflops(
fn,
N=N,
H=H,
W=W,
Ci=Ci,
Co=Co,
R=R,
S=S,
stride_val=stride_val,
pad_val=pad_val,
)
if __name__ == "__main__":
bench.run(save_path=".", print_data=True)