"""Conv Wgrad.

This example can be found at python/examples/gluon/02-conv-wgrad.py.
"""

import importlib.util
import sys
from pathlib import Path

import pytest
import torch

import triton
import triton.language as tl

from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor, TensorDescriptorIm2Col
from triton.experimental.gluon.language.nvidia.hopper import tma, mbarrier
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    allocate_tensor_memory,
    tensor_memory_descriptor,
    tcgen05_mma,
    tcgen05_commit,
)


def _load_conv_common():
    """Load the shared conv helper module (02-conv-common.py) exactly once.

    The helper filename is not a valid Python identifier, so it is loaded via
    importlib from the directory containing this file and cached in
    sys.modules under a synthetic module name.
    """
    module_name = "triton_examples_gluon_conv_common"
    # Reuse a previously loaded copy so repeated loads share module state.
    if module_name in sys.modules:
        return sys.modules[module_name]

    module_path = Path(__file__).with_name("02-conv-common.py")
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Unable to load shared conv helpers from {module_path}")

    module = importlib.util.module_from_spec(spec)
    # Register before exec_module so imports during execution can resolve it.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


# The shared helpers live in 02-conv-common.py next to this file; the filename
# is not importable directly, hence the dynamic loader above.
_conv_common = _load_conv_common()

# ===-----------------------------------------------------------------------===#
# Utilities
# ===-----------------------------------------------------------------------===#

# Re-export the shared helpers under local names so the rest of this file reads
# like a self-contained module.
Counter = _conv_common.Counter  # (index, phase) counter over a ring of N slots — see common module
GL_GEMM_DTYPE = _conv_common.GL_GEMM_DTYPE  # Gluon-side GEMM element dtype
PersistentTileScheduler = _conv_common.PersistentTileScheduler
TORCH_GEMM_DTYPE = _conv_common.TORCH_GEMM_DTYPE  # torch-side dtype (bfloat16, per the dtype check below)
init_mbarrier_ring = _conv_common.init_mbarrier_ring
invalidate_mbarrier_ring = _conv_common.invalidate_mbarrier_ring
is_blackwell = _conv_common.is_blackwell
# NOTE: local alias narrows the shared helper's more general name.
maybe_pad_ci_for_tma = _conv_common.maybe_pad_channel_dims_for_tma
normalize_2d = _conv_common.normalize_2d  # (int | (h, w)) -> (h, w) pair

# ===-----------------------------------------------------------------------===#
# Wgrad GEMM mapping
# ===-----------------------------------------------------------------------===#
#
# grad_W[Co, R*S*Ci] = grad_out[M, Co]^T  @  im2col(input)[M, R*S*Ci]
#
# where M = N * out_h * out_w   (spatial positions — reduction dimension)
#
# MMA tiling:
#   BLOCK_M = tile over Co           (rows of grad_weight)
#   BLOCK_N = tile over Ci per (r,s) (cols of grad_weight)
#   BLOCK_K = tile over spatial       (reduction)
#
# Logical tile space: cdiv(Co, BLOCK_M) * R * S * cdiv(Ci, BLOCK_N), optionally
# multiplied by split-K. The launch uses a persistent scheduler and runs only
# `min(num_sms, logical_tiles)` CTAs.
#
# Loads per K iteration:
#   A = grad_out tile: TMA tiled on (M_spatial, Co),
#       block [BLOCK_K, BLOCK_M] — permuted to [M, K] in kernel.
#   B = im2col(input) tile: TMA im2col on [N,H,W,Ci], block [BLOCK_K, BLOCK_N]
#       Already [K, N], no kernel permute.
#
# MMA: acc[BLOCK_M, BLOCK_N] += A.permute(1,0) @ B

# ===-----------------------------------------------------------------------===#
# Wgrad Configuration
# ===-----------------------------------------------------------------------===#


@gluon.aggregate
class WgradConfig:
    """Problem geometry plus tile configuration for the wgrad kernel.

    The wgrad GEMM is grad_W[Co, R*S*Ci] = grad_out[M, Co]^T @ im2col(input)
    with M = N * out_h * out_w as the reduction dimension (see the mapping
    comment above this class).
    """

    # Runtime problem shape.
    N: gl.tensor
    Ci: gl.tensor  # input channels (possibly padded for TMA)
    Co: gl.tensor  # output channels
    R: gl.tensor  # filter height
    S: gl.tensor  # filter width
    out_h: gl.tensor
    out_w: gl.tensor
    stride_h: gl.tensor
    stride_w: gl.tensor
    pad_h: gl.tensor
    pad_w: gl.tensor
    K_GEMM: gl.tensor  # R * S * Ci — columns of the flattened grad-weight
    M_spatial: gl.tensor  # N * out_h * out_w — GEMM reduction extent

    # Compile-time tiling parameters.
    BLOCK_M: gl.constexpr
    BLOCK_N: gl.constexpr
    BLOCK_K: gl.constexpr
    SPLIT_K: gl.constexpr
    num_buffers: gl.constexpr
    num_warps: gl.constexpr

    @gluon.jit
    def get_num_output_tiles(self):
        """Logical output tiles: Co-blocks * R * S * Ci-blocks."""
        co_num_blocks = gl.cdiv(self.Co, self.BLOCK_M)
        ci_num_blocks = gl.cdiv(self.Ci, self.BLOCK_N)
        return co_num_blocks * self.R * self.S * ci_num_blocks

    @gluon.jit
    def get_num_k_iterations(self):
        """Total BLOCK_K steps needed to cover the spatial reduction."""
        return gl.cdiv(self.M_spatial, self.BLOCK_K)

    @gluon.jit
    def get_active_split_k(self):
        """Split-K partitions that actually receive work (<= SPLIT_K)."""
        total_k_iters = self.get_num_k_iterations()
        k_iters_per_split = gl.cdiv(total_k_iters, self.SPLIT_K)
        return gl.cdiv(total_k_iters, k_iters_per_split)

    @gluon.jit
    def get_num_tiles(self):
        """Total scheduled work items: output tiles times active split-K."""
        return self.get_num_output_tiles() * self.get_active_split_k()

    @gluon.jit
    def get_program(self, pid):
        """Decode a flat scheduler pid into a WgradProgram work item.

        Decomposition order: split-K index varies fastest, then the Co block,
        then the combined (Ci block, r, s) index.
        """
        active_split_k = self.get_active_split_k()
        split_k_idx = pid % active_split_k
        tile_id = pid // active_split_k

        ci_num_blocks = gl.cdiv(self.Ci, self.BLOCK_N)
        co_num_blocks = gl.cdiv(self.Co, self.BLOCK_M)
        pid_co = tile_id % co_num_blocks
        pid_n = tile_id // co_num_blocks

        ci_block = pid_n % ci_num_blocks
        rs_idx = pid_n // ci_num_blocks
        iter_r = rs_idx // self.S
        iter_s = rs_idx % self.S

        # Clamp this split's K range to what remains; trailing splits can end
        # up with zero iterations when the division is uneven.
        total_k_iters = self.get_num_k_iterations()
        k_iters_per_split = gl.cdiv(total_k_iters, active_split_k)
        k_start = split_k_idx * k_iters_per_split
        remaining_k_iters = total_k_iters - k_start
        zero = gl.to_tensor(0)
        k_iters_this_split = gl.where(
            remaining_k_iters > 0,
            gl.minimum(k_iters_per_split, remaining_k_iters),
            zero,
        )

        return WgradProgram(self, pid_co, ci_block, iter_r, iter_s, split_k_idx, k_start, k_iters_this_split)


@gluon.aggregate
class WgradProgram:
    """One scheduled work item: an output tile plus its slice of the K range."""

    config: WgradConfig
    pid_co: gl.tensor  # Co block index
    ci_block: gl.tensor  # Ci block index within the current (r, s)
    iter_r: gl.tensor  # filter row
    iter_s: gl.tensor  # filter column
    split_k_idx: gl.tensor
    k_start: gl.tensor  # first K iteration owned by this split
    k_iters_this_split: gl.tensor  # may be zero for trailing splits

    @gluon.jit
    def get_co_offset(self):
        """Starting output-channel index of this tile."""
        return self.pid_co * self.config.BLOCK_M

    @gluon.jit
    def get_ci_offset(self):
        """Starting input-channel index of this tile."""
        return self.ci_block * self.config.BLOCK_N

    @gluon.jit
    def get_spatial_offsets(self, local_k):
        """Decompose K iteration `local_k` into (m_global, batch, out_y, out_x).

        m_global is the flat spatial index of the first row of this K block;
        the remaining values locate it inside the [N, out_h, out_w] volume.
        """
        m_global = (self.k_start + local_k) * self.config.BLOCK_K
        spatial_per_batch = self.config.out_h * self.config.out_w
        m_in_batch = m_global % spatial_per_batch
        batch = m_global // spatial_per_batch
        out_x = m_in_batch % self.config.out_w
        out_y = m_in_batch // self.config.out_w
        return m_global, batch, out_y, out_x

    @gluon.jit
    def get_weight_k_offset(self):
        """Column offset into the flattened grad-weight [Co, R*S*Ci] layout."""
        return (self.iter_r * self.config.S + self.iter_s) * self.config.Ci + self.get_ci_offset()


# ===-----------------------------------------------------------------------===#
# Partition Arguments
# ===-----------------------------------------------------------------------===#


@gluon.aggregate
class PartitionArgs:
    """Shared state handed to every warp-specialized partition."""

    config: WgradConfig
    in_desc: tma.tensor_descriptor_im2col  # activations [N, H, W, Ci], im2col TMA
    grad_out_desc: tma.tensor_descriptor  # grad_output viewed as (M_spatial, Co)
    grad_weight_ptr: gl.tensor  # final output, or split-K workspace when active
    grad_weight_stride_0: gl.tensor  # row stride of the output buffer
    # SMEM tile rings for the load <-> MMA handoff, with their barrier rings.
    a_bufs: gl.shared_memory_descriptor  # grad_out tiles [BLOCK_K, BLOCK_M]
    b_bufs: gl.shared_memory_descriptor  # im2col input tiles [BLOCK_K, BLOCK_N]
    load_empty_bars: gl.shared_memory_descriptor
    load_ready_bars: gl.shared_memory_descriptor
    # Tensor-memory accumulator ring for the MMA <-> epilogue handoff.
    acc_bufs: tensor_memory_descriptor
    acc_empty_bars: gl.shared_memory_descriptor
    acc_ready_bars: gl.shared_memory_descriptor


# ===-----------------------------------------------------------------------===#
# Warp-Specialized Partitions
# ===-----------------------------------------------------------------------===#


@gluon.jit
def load_partition(p):
    """Load partition: iterate over the persistent wgrad work items assigned to this CTA.

    For each K iteration of each work item: wait for a free SMEM slot, then
    issue the grad_out TMA load and the im2col activation TMA load into that
    slot, both completing on the same `ready` barrier.
    """
    config = p.config

    empty_bars = p.load_empty_bars
    ready_bars = p.load_ready_bars
    # NOTE(review): initial phase 1 here vs. 0 in the MMA consumer matches
    # init_mbarrier_ring's convention — see the shared common module.
    state = Counter.create(1, empty_bars.shape[0])
    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())

    for idx in range(scheduler.get_num_tiles()):
        prog = config.get_program(scheduler.get_tile_id(idx))
        co_offset = prog.get_co_offset()
        ci_offset = prog.get_ci_offset()

        for local_k in range(prog.k_iters_this_split):
            m_global, batch, out_y, out_x = prog.get_spatial_offsets(local_k)
            ready_bar = ready_bars.index(state.index)
            # Wait for the MMA partition to release this buffer slot.
            mbarrier.wait(empty_bars.index(state.index), state.phase)
            # One barrier tracks both loads: expect their combined byte count.
            mbarrier.expect(ready_bar, p.grad_out_desc.block_type.nbytes + p.in_desc.block_type.nbytes)

            # A = grad_output: (M_spatial, Co), block [BLOCK_K, BLOCK_M]
            tma.async_load(
                p.grad_out_desc,
                [m_global, co_offset],
                ready_bar,
                p.a_bufs.index(state.index),
            )

            # B = im2col(input): [N, H, W, Ci], block [BLOCK_K, BLOCK_N]
            tma.async_load_im2col(
                p.in_desc,
                [
                    batch,
                    # Top-left input pixel for this output position; may be
                    # negative inside padding (descriptor uses padding="zero").
                    out_y * config.stride_h - config.pad_h,
                    out_x * config.stride_w - config.pad_w,
                    ci_offset,
                ],
                # Filter-tap (r, s) offsets within the im2col bounding box.
                [prog.iter_r.to(tl.int16), prog.iter_s.to(tl.int16)],
                ready_bar,
                p.b_bufs.index(state.index),
            )
            state = state.next()


@gluon.jit
def mma_partition(p):
    """MMA partition: accumulate all split-K work items assigned to this CTA.

    Consumes loaded SMEM tiles, chains tcgen05 MMAs into one tensor-memory
    accumulator per work item, and signals the epilogue when a tile is done.
    """
    config = p.config
    load_state = Counter.create(0, p.load_empty_bars.shape[0])
    acc_state = Counter.create(1, p.acc_empty_bars.shape[0])
    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())

    for idx in range(scheduler.get_num_tiles()):
        prog = config.get_program(scheduler.get_tile_id(idx))

        # Wait until the epilogue has drained this accumulator slot.
        mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase)
        acc_buf = p.acc_bufs.index(acc_state.index)
        use_acc = False  # first MMA of a tile overwrites the accumulator

        for _local_k in range(prog.k_iters_this_split):
            mbarrier.wait(p.load_ready_bars.index(load_state.index), load_state.phase)
            # A is stored [BLOCK_K, BLOCK_M]; permute to [BLOCK_M, BLOCK_K]
            # so the MMA computes grad_out^T @ im2col(input).
            tcgen05_mma(
                p.a_bufs.index(load_state.index).permute((1, 0)),
                p.b_bufs.index(load_state.index),
                acc_buf,
                use_acc=use_acc,
            )
            # Release the SMEM slot back to the load partition once the MMA retires.
            tcgen05_commit(p.load_empty_bars.index(load_state.index))
            load_state = load_state.next()
            use_acc = True

        # Tell the epilogue this accumulator holds a finished tile.
        tcgen05_commit(p.acc_ready_bars.index(acc_state.index))
        acc_state = acc_state.next()


@gluon.jit
def epilogue_partition(p):
    """Epilogue partition: store the persistent wgrad work items assigned to this CTA.

    Reads finished accumulators from tensor memory and writes them either to
    the final grad-weight buffer (single split) or to this split's row block
    of the workspace that a second pass reduces.
    """
    config = p.config
    active_split_k = config.get_active_split_k()
    BLOCK_M: gl.constexpr = config.BLOCK_M
    BLOCK_N: gl.constexpr = config.BLOCK_N
    acc_state = Counter.create(0, p.acc_empty_bars.shape[0])
    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())

    for idx in range(scheduler.get_num_tiles()):
        prog = config.get_program(scheduler.get_tile_id(idx))
        co_offset = prog.get_co_offset()
        ci_offset = prog.get_ci_offset()
        weight_k_offset = prog.get_weight_k_offset()

        mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
        acc = p.acc_bufs.index(acc_state.index).load()
        result = gl.convert_layout(acc, gl.CoalescedLayout())
        # Release the accumulator to the MMA partition before the global store.
        mbarrier.arrive(p.acc_empty_bars.index(acc_state.index), count=1)
        acc_state = acc_state.next()

        # With split-K active, each split owns its own Co-sized row block.
        split_co_offset = gl.where(active_split_k > 1, prog.split_k_idx * config.Co, gl.to_tensor(0))
        offs_m = co_offset + gl.arange(0, BLOCK_M)
        offs_n = weight_k_offset + gl.arange(0, BLOCK_N)

        # Mask rows past Co, columns past K_GEMM, and the Ci-block overrun of
        # this (r, s) group when Ci is not a multiple of BLOCK_N.
        ci_valid = (ci_offset + gl.arange(0, BLOCK_N)) < config.Ci
        mask = (offs_m[:, None] < config.Co) & (offs_n[None, :] < config.K_GEMM) & ci_valid[None, :]
        store_rows = split_co_offset + offs_m
        offsets = store_rows[:, None] * p.grad_weight_stride_0 + offs_n[None, :]
        gl.store(p.grad_weight_ptr + offsets, result, mask=mask)

    # Tear down all barrier rings once every assigned tile is stored.
    invalidate_mbarrier_ring(p.load_empty_bars)
    invalidate_mbarrier_ring(p.load_ready_bars)
    invalidate_mbarrier_ring(p.acc_empty_bars)
    invalidate_mbarrier_ring(p.acc_ready_bars)


# ===-----------------------------------------------------------------------===#
# Kernel Entry Point
# ===-----------------------------------------------------------------------===#


@gluon.jit(do_not_specialize=[
    "N",
    "R",
    "S",
    "out_h",
    "out_w",
    "stride_h",
    "stride_w",
    "pad_h",
    "pad_w",
])
def conv2d_wgrad_kernel(
    in_desc,
    grad_out_desc,
    grad_weight,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    K_GEMM,
    grad_weight_stride_0,
    BLOCK_M: gl.constexpr,
    BLOCK_N: gl.constexpr,
    BLOCK_K: gl.constexpr,
    SPLIT_K: gl.constexpr,
    num_buffers: gl.constexpr,
    num_acc_buffers: gl.constexpr,
    num_warps: gl.constexpr,
):
    """Warp-specialized wgrad kernel: grad_W = grad_out^T @ im2col(input).

    GEMM dimensions (per CTA):
        M = Co tile                  (output rows)
        N = Ci tile at fixed (r,s)   (output cols)
        K = N_batch * out_h * out_w  (spatial reduction, split across SPLIT_K CTAs)

    Sets up the SMEM/TMEM buffer rings and their mbarriers, then hands control
    to three specialized partitions (load / MMA / epilogue).
    """
    M_spatial = N * out_h * out_w
    config = WgradConfig(
        N,
        Ci,
        Co,
        R,
        S,
        gl.to_tensor(out_h),
        gl.to_tensor(out_w),
        gl.to_tensor(stride_h),
        gl.to_tensor(stride_w),
        pad_h,
        pad_w,
        K_GEMM,
        M_spatial,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        SPLIT_K,
        num_buffers,
        num_warps,
    )

    # a_bufs: grad_output tiles [BLOCK_K, BLOCK_M] (spatial × Co)
    # TMA loads from (M_spatial, Co), permuted to [BLOCK_M, BLOCK_K] at MMA call.
    a_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_M], GL_GEMM_DTYPE)
    # b_bufs: im2col input tiles [BLOCK_K, BLOCK_N] (spatial × Ci)
    b_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], GL_GEMM_DTYPE)

    a_bufs = gl.allocate_shared_memory(GL_GEMM_DTYPE, [num_buffers, BLOCK_K, BLOCK_M], a_smem_layout)
    b_bufs = gl.allocate_shared_memory(GL_GEMM_DTYPE, [num_buffers, BLOCK_K, BLOCK_N], b_smem_layout)

    # Barrier rings for the load <-> MMA producer/consumer handoff.
    load_empty_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    load_ready_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    init_mbarrier_ring(load_empty_bars)
    init_mbarrier_ring(load_ready_bars)

    # Accumulators live in tensor memory; 128-row TMEM blocks unless BLOCK_M is 64.
    TMEM_BLOCK_M: gl.constexpr = 64 if BLOCK_M == 64 else 128
    tmem_layout: gl.constexpr = TensorMemoryLayout(block=(TMEM_BLOCK_M, BLOCK_N), col_stride=1)
    acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout)

    # Barrier rings for the MMA <-> epilogue handoff.
    acc_empty_bars = gl.allocate_shared_memory(gl.int64, [num_acc_buffers, 1], mbarrier.MBarrierLayout())
    acc_ready_bars = gl.allocate_shared_memory(gl.int64, [num_acc_buffers, 1], mbarrier.MBarrierLayout())
    init_mbarrier_ring(acc_empty_bars)
    init_mbarrier_ring(acc_ready_bars)

    p = PartitionArgs(
        config,
        in_desc,
        grad_out_desc,
        grad_weight,
        gl.to_tensor(grad_weight_stride_0),
        a_bufs,
        b_bufs,
        load_empty_bars,
        load_ready_bars,
        acc_bufs,
        acc_empty_bars,
        acc_ready_bars,
    )

    # NOTE(review): the two trailing lists look like per-partition warp counts
    # and register budgets — confirm against gl.warp_specialize's signature.
    gl.warp_specialize([
        (epilogue_partition, (p, )),
        (mma_partition, (p, )),
        (load_partition, (p, )),
    ], [1, 1], [24, 24])


# ===-----------------------------------------------------------------------===#
# Autotuning
# ===-----------------------------------------------------------------------===#


def conv2d_wgrad_get_configs(pre_hook=None):
    """Enumerate the candidate kernel configurations explored by autotuning.

    BLOCK_K, num_acc_buffers, and num_warps are fixed; the sweep covers tile
    shape, split-K factor, and load-buffer depth.
    """
    configs = []
    for block_m in (64, 128):
        for block_n in (64, 128, 256):
            for split_k in (1, 2, 4, 8, 16, 32):
                for num_buffers in (3, 4):
                    configs.append(
                        triton.Config(
                            {
                                "BLOCK_M": block_m,
                                "BLOCK_N": block_n,
                                "BLOCK_K": 64,
                                "SPLIT_K": split_k,
                                "num_buffers": num_buffers,
                                "num_acc_buffers": 2,
                            },
                            num_warps=4,
                            pre_hook=pre_hook,
                        ))
    return configs


# ===-----------------------------------------------------------------------===#
# Host-Side Entry Point
# ===-----------------------------------------------------------------------===#


def _prepare_wgrad_problem(input_nhwc, grad_output_nhwc, R, S, stride, padding):
    """Validate inputs, pad channels, and return derived quantities.

    Args:
        input_nhwc: activation tensor [N, H, W, Ci] in bfloat16.
        grad_output_nhwc: output gradient [N, out_h, out_w, Co] in bfloat16.
        R, S: filter height and width.
        stride, padding: int or (h, w) pair.

    Returns:
        (input_nhwc, grad_output_nhwc, Ci_orig, N, Ci, Co, out_h, out_w,
        stride_h, stride_w, pad_h, pad_w, K_GEMM); input_nhwc may have its
        channel dim padded for TMA, and Ci/K_GEMM reflect the padded size.

    Raises:
        ValueError: on dtype, stride/padding, or shape mismatches.
    """
    if input_nhwc.dtype != TORCH_GEMM_DTYPE or grad_output_nhwc.dtype != TORCH_GEMM_DTYPE:
        raise ValueError(
            f"conv2d_wgrad expects bfloat16 input and grad-output tensors, got {input_nhwc.dtype} and {grad_output_nhwc.dtype}"
        )

    stride_h, stride_w = normalize_2d(stride, "stride")
    pad_h, pad_w = normalize_2d(padding, "padding")
    if stride_h <= 0 or stride_w <= 0:
        raise ValueError(f"stride must be positive, got {(stride_h, stride_w)}")
    if pad_h < 0 or pad_w < 0:
        raise ValueError(f"padding must be non-negative, got {(pad_h, pad_w)}")

    N, H, W, Ci_orig = input_nhwc.shape
    N2, out_h, out_w, Co = grad_output_nhwc.shape
    # Raise instead of assert: input validation must survive `python -O`,
    # and ValueError is consistent with the other checks in this function.
    if N != N2:
        raise ValueError(f"Batch size mismatch: input has N={N}, grad_output has N={N2}")

    expected_out_h = (H + 2 * pad_h - R) // stride_h + 1
    expected_out_w = (W + 2 * pad_w - S) // stride_w + 1
    if out_h != expected_out_h or out_w != expected_out_w:
        raise ValueError("Grad-output shape mismatch: expected "
                         f"({N}, {expected_out_h}, {expected_out_w}, {Co}) from input/filter geometry, got "
                         f"({N2}, {out_h}, {out_w}, {Co}).")
    if out_h <= 0 or out_w <= 0:
        raise ValueError("Invalid convolution geometry for wgrad")

    # Pad Ci to a TMA-friendly size; the padded channels are dropped again in
    # _finalize_wgrad_output.
    input_nhwc = maybe_pad_ci_for_tma(input_nhwc)
    Ci = input_nhwc.shape[-1]
    K_GEMM = R * S * Ci

    return input_nhwc, grad_output_nhwc, Ci_orig, N, Ci, Co, out_h, out_w, stride_h, stride_w, pad_h, pad_w, K_GEMM


def _allocate_wgrad_output(device, Co, K_GEMM):
    return torch.zeros((Co, K_GEMM), device=device, dtype=torch.float32)


def _make_wgrad_descriptors(input_nhwc, grad_output_nhwc, Co, out_h, out_w, stride_h, stride_w, pad_h, pad_w,
                            input_block_shape, grad_out_block_shape):
    """Create TMA descriptors for wgrad im2col and grad_output.

    Args:
        input_block_shape: [BLOCK_K, BLOCK_N] box for the im2col load.
        grad_out_block_shape: [BLOCK_K, BLOCK_M] box for the grad_out load.
    """
    # TMA im2col descriptor for the activation tensor [N, H, W, Ci] in NHWC.
    _, H, W, _ = input_nhwc.shape
    # Far-corner offsets of the im2col access box relative to the input extent.
    # NOTE(review): this follows the TMA im2col corner convention (offsets
    # from the image's far edge) — confirm against the tensor-map docs.
    upper_h = (out_h - 1) * stride_h + 1 - H - pad_h
    upper_w = (out_w - 1) * stride_w + 1 - W - pad_w

    input_layout = gl.NVMMASharedLayout.get_default_for(input_block_shape, GL_GEMM_DTYPE)
    in_desc = TensorDescriptorIm2Col(
        base=input_nhwc,
        shape=list(input_nhwc.shape),
        strides=list(input_nhwc.stride()),
        block_shape=input_block_shape,
        layout=input_layout,
        padding="zero",  # out-of-bounds (halo) pixels read as 0
        element_strides=[1, stride_h, stride_w, 1],
        pixel_box_lower_corner=[-pad_h, -pad_w],
        pixel_box_upper_corner=[upper_h, upper_w],
    )

    # TMA tiled descriptor for grad_output reshaped as (M_spatial, Co).
    M_spatial = input_nhwc.shape[0] * out_h * out_w
    grad_out_2d = grad_output_nhwc.reshape(M_spatial, Co)
    grad_out_layout = gl.NVMMASharedLayout.get_default_for(grad_out_block_shape, GL_GEMM_DTYPE)
    grad_out_desc = TensorDescriptor.from_tensor(grad_out_2d, grad_out_block_shape, grad_out_layout)

    return in_desc, grad_out_desc


def _make_grid(num_sms, M_spatial, Co, Ci, R, S):
    """Build a persistent-launch grid callback capped at the SM count."""

    def grid(meta):
        # Logical tiles: Co-blocks x (r, s) taps x Ci-blocks x active split-K.
        m_blocks = triton.cdiv(Co, meta["BLOCK_M"])
        n_blocks = triton.cdiv(Ci, meta["BLOCK_N"])
        k_iters = triton.cdiv(M_spatial, meta["BLOCK_K"])
        per_split = triton.cdiv(k_iters, meta["SPLIT_K"])
        splits = triton.cdiv(k_iters, per_split)
        tiles = m_blocks * R * S * n_blocks * splits
        # Persistent scheduling: never launch more CTAs than there is work.
        return (min(num_sms, tiles), )

    return grid


def _get_active_split_k(M_spatial, BLOCK_K, SPLIT_K):
    """Number of split-K partitions that actually receive work (<= SPLIT_K).

    Mirrors WgradConfig.get_active_split_k so host and device agree.
    """
    k_iters = triton.cdiv(M_spatial, BLOCK_K)
    per_split = triton.cdiv(k_iters, SPLIT_K)
    return triton.cdiv(k_iters, per_split)


def _get_safe_wgrad_active_split_k(M_spatial, Co, K_GEMM, kernel_meta):
    """Resolve the active split-K count, rejecting oversized workspaces.

    Raises ValueError when the split-K workspace would exceed the int32
    element-index range used by the kernel's row * stride + col addressing.
    """
    active_split_k = _get_active_split_k(M_spatial, kernel_meta["BLOCK_K"], kernel_meta["SPLIT_K"])
    if active_split_k <= 1:
        return active_split_k

    # The split-K workspace is indexed as row * stride + col inside the kernel.
    # Very large workspaces can exceed the addressing range supported by the generated code.
    workspace_elems = active_split_k * Co * K_GEMM
    if workspace_elems > (2**31 - 1):
        raise ValueError("wgrad split-K workspace exceeds safe indexing range: "
                         f"active_split_k={active_split_k}, Co={Co}, K_GEMM={K_GEMM}")
    return active_split_k


def _allocate_wgrad_split_k_workspace(device, active_split_k, Co, K_GEMM):
    return torch.empty((active_split_k * Co, K_GEMM), device=device, dtype=torch.float32)


# Process-lifetime cache: (device capability, problem geometry) -> best kernel
# meta, so each shape is autotuned at most once per process.
_wgrad_autotune_cache = {}


def _make_wgrad_autotune_key(
    device,
    num_sms,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
):
    """Cache key for the autotuner: GPU architecture plus problem geometry."""
    capability = torch.cuda.get_device_capability(device)
    geometry = (N, Ci, Co, R, S, out_h, out_w, stride_h, stride_w, pad_h, pad_w)
    return (capability, num_sms) + geometry


def _make_wgrad_runner(
    input_nhwc,
    grad_output_nhwc,
    grad_weight_flat,
    *,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    K_GEMM,
    num_sms,
    kernel_meta,
):
    """Build a zero-argument callable that runs one full wgrad computation.

    When the active split-K exceeds 1, the kernel writes per-split partials
    into a float32 workspace and a deterministic second-pass kernel reduces
    them into `grad_weight_flat`; otherwise the kernel writes
    `grad_weight_flat` directly.
    """
    M_spatial = N * out_h * out_w
    # Raises if the split-K workspace would exceed safe indexing range.
    active_split_k = _get_safe_wgrad_active_split_k(M_spatial, Co, K_GEMM, kernel_meta)
    uses_split_k_workspace = active_split_k > 1
    launch_output = grad_weight_flat
    if uses_split_k_workspace:
        launch_output = _allocate_wgrad_split_k_workspace(input_nhwc.device, active_split_k, Co, K_GEMM)

    # Descriptors capture the tile shapes, so they depend on kernel_meta.
    in_desc, grad_out_desc = _make_wgrad_descriptors(
        input_nhwc,
        grad_output_nhwc,
        Co,
        out_h,
        out_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        [kernel_meta["BLOCK_K"], kernel_meta["BLOCK_N"]],
        [kernel_meta["BLOCK_K"], kernel_meta["BLOCK_M"]],
    )
    grid = _make_grid(num_sms, M_spatial, Co, Ci, R, S)

    def run():
        _launch_wgrad(
            conv2d_wgrad_kernel,
            grid,
            in_desc=in_desc,
            grad_out_desc=grad_out_desc,
            grad_weight=launch_output,
            N=N,
            Ci=Ci,
            Co=Co,
            R=R,
            S=S,
            out_h=out_h,
            out_w=out_w,
            stride_h=stride_h,
            stride_w=stride_w,
            pad_h=pad_h,
            pad_w=pad_w,
            K_GEMM=K_GEMM,
            kernel_meta=kernel_meta,
        )
        if uses_split_k_workspace:
            # Deterministic fixed-order reduction of the per-split partials.
            _reduce_wgrad_split_k_partials(launch_output, grad_weight_flat, Co, K_GEMM, active_split_k)

    return run


def _benchmark_wgrad_config(
    input_nhwc,
    grad_output_nhwc,
    *,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    K_GEMM,
    num_sms,
    kernel_meta,
):
    """Time one kernel configuration; returns +inf when it fails.

    The broad `except Exception` is deliberate: configurations that cannot
    build or launch (e.g. resource limits, oversized split-K workspace) must
    simply lose the autotune race rather than abort it.
    """
    try:
        # Scratch output — benchmark results are discarded.
        grad_weight_flat = torch.empty((Co, K_GEMM), device=input_nhwc.device, dtype=torch.float32)
        run = _make_wgrad_runner(
            input_nhwc,
            grad_output_nhwc,
            grad_weight_flat,
            N=N,
            Ci=Ci,
            Co=Co,
            R=R,
            S=S,
            out_h=out_h,
            out_w=out_w,
            stride_h=stride_h,
            stride_w=stride_w,
            pad_h=pad_h,
            pad_w=pad_w,
            K_GEMM=K_GEMM,
            num_sms=num_sms,
            kernel_meta=kernel_meta,
        )
        # Warm up (triggers compilation) and surface launch errors before timing.
        run()
        torch.cuda.synchronize()
        return triton.testing.do_bench(run)
    except Exception:
        return float("inf")


def _select_wgrad_kernel_meta(
    input_nhwc,
    grad_output_nhwc,
    *,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    K_GEMM,
    num_sms,
):
    """Pick the fastest kernel configuration for this problem (memoized).

    Benchmarks every candidate from conv2d_wgrad_get_configs on the real
    tensors; failing configs score +inf. Returns a fresh dict copy so callers
    cannot mutate the cache entry.
    """
    cache_key = _make_wgrad_autotune_key(input_nhwc.device, num_sms, N, Ci, Co, R, S, out_h, out_w, stride_h, stride_w,
                                         pad_h, pad_w)
    cached = _wgrad_autotune_cache.get(cache_key)
    if cached is not None:
        return dict(cached)

    best_ms = float("inf")
    best_kernel_meta = None
    for config in conv2d_wgrad_get_configs():
        kernel_meta = config.all_kwargs()
        ms = _benchmark_wgrad_config(
            input_nhwc,
            grad_output_nhwc,
            N=N,
            Ci=Ci,
            Co=Co,
            R=R,
            S=S,
            out_h=out_h,
            out_w=out_w,
            stride_h=stride_h,
            stride_w=stride_w,
            pad_h=pad_h,
            pad_w=pad_w,
            K_GEMM=K_GEMM,
            num_sms=num_sms,
            kernel_meta=kernel_meta,
        )
        if ms < best_ms:
            best_ms = ms
            best_kernel_meta = dict(kernel_meta)

    # Every config returned +inf: nothing compiled and ran on this problem.
    if best_kernel_meta is None:
        raise RuntimeError("Failed to autotune conv2d_wgrad: no valid kernel configurations.")

    _wgrad_autotune_cache[cache_key] = dict(best_kernel_meta)
    return dict(best_kernel_meta)


@triton.jit
def reduce_split_k_partials_kernel(
    partial_ptr,  # [ACTIVE_SPLIT_K * Co, K_GEMM] stacked per-split partials
    grad_weight_ptr,  # [Co, K_GEMM] final output
    partial_stride_0,
    grad_weight_stride_0,
    Co,
    K_GEMM,
    ACTIVE_SPLIT_K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Sum split-K partials into the final grad-weight buffer.

    Each program reduces one [BLOCK_M, BLOCK_N] output tile across all splits
    in a fixed loop order with no atomics, making the result deterministic.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < Co) & (offs_n[None, :] < K_GEMM)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for split_k_idx in range(ACTIVE_SPLIT_K):
        # Split s stores its partial at row offset s * Co (see epilogue_partition).
        partial_rows = split_k_idx * Co + offs_m
        partial_offsets = partial_rows[:, None] * partial_stride_0 + offs_n[None, :]
        acc += tl.load(partial_ptr + partial_offsets, mask=mask, other=0.0)

    grad_weight_offsets = offs_m[:, None] * grad_weight_stride_0 + offs_n[None, :]
    tl.store(grad_weight_ptr + grad_weight_offsets, acc, mask=mask)


def _reduce_wgrad_split_k_partials(partials, grad_weight_flat, Co, K_GEMM, active_split_k):
    """Launch the deterministic reduction of split-K partials into the output."""
    block_m, block_n = 64, 64
    # One program per [block_m, block_n] tile of the final [Co, K_GEMM] output.
    launch_grid = (triton.cdiv(Co, block_m), triton.cdiv(K_GEMM, block_n))
    reduce_split_k_partials_kernel[launch_grid](
        partials,
        grad_weight_flat,
        partials.stride(0),
        grad_weight_flat.stride(0),
        Co,
        K_GEMM,
        ACTIVE_SPLIT_K=active_split_k,
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        num_warps=4,
    )


def _launch_wgrad(
    kernel,
    grid,
    *,
    in_desc,
    grad_out_desc,
    grad_weight,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    K_GEMM,
    kernel_meta=None,
):
    """Launch `kernel` on `grid` with the wgrad problem description.

    Positional arguments are the TMA descriptors, the output buffer, the
    problem geometry, and the output row stride; everything in `kernel_meta`
    (tile sizes, buffering, num_warps) is forwarded as keyword arguments.
    """
    meta = {} if kernel_meta is None else kernel_meta
    problem_args = (
        in_desc,
        grad_out_desc,
        grad_weight,
        N,
        Ci,
        Co,
        R,
        S,
        out_h,
        out_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        K_GEMM,
        grad_weight.stride(0),
    )
    kernel[grid](*problem_args, **meta)


def _finalize_wgrad_output(grad_weight_flat, Co, R, S, Ci, Ci_orig):
    """Reshape the flat [Co, K_GEMM] buffer to (Co, R, S, Ci) in GEMM dtype.

    When the channel dimension was padded for TMA, slice it back down to the
    caller's original Ci.
    """
    grad_w = grad_weight_flat.reshape(Co, R, S, Ci).to(TORCH_GEMM_DTYPE)
    if Ci == Ci_orig:
        return grad_w
    return grad_w[:, :, :, :Ci_orig].contiguous()


def conv2d_wgrad(input_nhwc, grad_output_nhwc, R, S, stride=1, padding=0):
    """Production wgrad entrypoint.

    Selects the best kernel configuration with host-side autotuning, then runs
    deterministic two-pass split-K when reduction is needed.

    Args:
        input_nhwc: activations [N, H, W, Ci], bfloat16.
        grad_output_nhwc: output gradient [N, out_h, out_w, Co], bfloat16.
        R, S: filter height and width.
        stride, padding: int or (h, w) pair.

    Returns:
        Weight gradient in NHWC filter layout (Co, R, S, Ci), bfloat16.
    """
    # Validates shapes/dtypes and pads Ci for TMA if needed.
    (input_nhwc, grad_output_nhwc, Ci_orig, N, Ci, Co,
     out_h, out_w, stride_h, stride_w, pad_h, pad_w, K_GEMM) = \
        _prepare_wgrad_problem(input_nhwc, grad_output_nhwc, R, S, stride, padding)
    grad_weight_flat = _allocate_wgrad_output(input_nhwc.device, Co, K_GEMM)

    # CTA cap for the persistent grid.
    num_sms = torch.cuda.get_device_properties(input_nhwc.device).multi_processor_count

    # Cached per problem geometry — only the first call for a shape benchmarks.
    kernel_meta = _select_wgrad_kernel_meta(
        input_nhwc,
        grad_output_nhwc,
        N=N,
        Ci=Ci,
        Co=Co,
        R=R,
        S=S,
        out_h=out_h,
        out_w=out_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        K_GEMM=K_GEMM,
        num_sms=num_sms,
    )
    run = _make_wgrad_runner(
        input_nhwc,
        grad_output_nhwc,
        grad_weight_flat,
        N=N,
        Ci=Ci,
        Co=Co,
        R=R,
        S=S,
        out_h=out_h,
        out_w=out_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        K_GEMM=K_GEMM,
        num_sms=num_sms,
        kernel_meta=kernel_meta,
    )
    run()

    return _finalize_wgrad_output(grad_weight_flat, Co, R, S, Ci, Ci_orig)


def _make_wgrad_fixed_kernel_meta(SPLIT_K, num_buffers, num_warps):
    # Keep the fixed path on a tile shape that is also covered by autotune configs.
    return {
        "BLOCK_M": 128,
        "BLOCK_N": 256,
        "BLOCK_K": 64,
        "SPLIT_K": SPLIT_K,
        "num_buffers": num_buffers,
        "num_acc_buffers": 2,
        "num_warps": num_warps,
    }


def conv2d_wgrad_fixed(input_nhwc, grad_output_nhwc, R, S, stride=1, padding=0, num_buffers=2, num_warps=4, SPLIT_K=1):
    """Fixed-config wgrad entrypoint used for CI and debugging.

    Runs the kernel with a fixed supported tile shape instead of autotuning,
    while still using deterministic two-pass split-K when reduction is needed.

    Args:
        input_nhwc / grad_output_nhwc: bfloat16 NHWC tensors (see conv2d_wgrad).
        num_buffers, num_warps, SPLIT_K: kernel knobs forwarded to the meta.

    Returns:
        Weight gradient in NHWC filter layout (Co, R, S, Ci), bfloat16.
    """
    (input_nhwc, grad_output_nhwc, Ci_orig, N, Ci, Co,
     out_h, out_w, stride_h, stride_w, pad_h, pad_w, K_GEMM) = \
        _prepare_wgrad_problem(input_nhwc, grad_output_nhwc, R, S, stride, padding)
    grad_weight_flat = _allocate_wgrad_output(input_nhwc.device, Co, K_GEMM)

    num_sms = torch.cuda.get_device_properties(input_nhwc.device).multi_processor_count
    # Fixed 128x256x64 tile shape; no benchmarking on this path.
    kernel_meta = _make_wgrad_fixed_kernel_meta(SPLIT_K, num_buffers, num_warps)
    run = _make_wgrad_runner(
        input_nhwc,
        grad_output_nhwc,
        grad_weight_flat,
        N=N,
        Ci=Ci,
        Co=Co,
        R=R,
        S=S,
        out_h=out_h,
        out_w=out_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        K_GEMM=K_GEMM,
        num_sms=num_sms,
        kernel_meta=kernel_meta,
    )
    run()

    return _finalize_wgrad_output(grad_weight_flat, Co, R, S, Ci, Ci_orig)


# ===-----------------------------------------------------------------------===#
# Unit Tests
# ===-----------------------------------------------------------------------===#


def _assert_wgrad_correct(wgrad_fn, N, Ci, H, W, Co, R, S, stride, padding, **kwargs):
    """Run wgrad_fn and compare against PyTorch autograd reference."""
    torch.manual_seed(0)
    stride_h, stride_w = normalize_2d(stride, "stride")
    pad_h, pad_w = normalize_2d(padding, "padding")

    # Random NCHW activations; the kernel under test consumes the NHWC copy.
    x_nchw = torch.randn((N, Ci, H, W), device="cuda", dtype=TORCH_GEMM_DTYPE)
    x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()

    out_h = (H + 2 * pad_h - R) // stride_h + 1
    out_w = (W + 2 * pad_w - S) // stride_w + 1

    grad_out_nchw = torch.randn((N, Co, out_h, out_w), device="cuda", dtype=TORCH_GEMM_DTYPE)
    grad_out_nhwc = grad_out_nchw.permute(0, 2, 3, 1).contiguous()

    # Reference weight gradient via autograd on an NCHW convolution.
    w_nchw = torch.randn((Co, Ci, R, S), device="cuda", dtype=TORCH_GEMM_DTYPE)
    w_ref = w_nchw.detach().requires_grad_(True)
    out_ref = torch.nn.functional.conv2d(x_nchw, w_ref, stride=(stride_h, stride_w), padding=(pad_h, pad_w))
    out_ref.backward(grad_out_nchw)
    ref_grad_w_nhwc = w_ref.grad.permute(0, 2, 3, 1).contiguous()

    triton_grad_w = wgrad_fn(x_nhwc, grad_out_nhwc, R, S, stride=stride, padding=padding, **kwargs)
    # Loose atol: bf16 inputs over a long reduction accumulate rounding error.
    torch.testing.assert_close(triton_grad_w, ref_grad_w_nhwc, atol=1, rtol=0.01)


# Full sweep over batch, spatial, channel, filter, stride and padding
# combinations, plus a few hand-picked corner cases.
_WGRAD_TEST_CASES = [
    (conv2d_wgrad_fixed, N, Ci, H, W, Co, R, S, stride, padding)
    for N in (1, 128)
    for H, W in ((64, 64), (64, 32))
    for Ci, Co in ((128, 128), (384, 384), (128, 384))
    for R, S in ((1, 1), (2, 2), (3, 3), (1, 3))
    for stride in (1, 2, 3)
    for padding in (0, 1)
] + [
    (conv2d_wgrad_fixed, 16, 5, 32, 32, 96, 3, 3, 1, 1),  # padded channels
    (conv2d_wgrad_fixed, 16, 96, 1, 8, 128, 1, 2, (1, 2), 0),  # asymmetric stride
    (conv2d_wgrad_fixed, 16, 512, 2, 2, 768, 2, 2, (2, 2), 0),  # small spatial
]


@pytest.mark.parametrize("wgrad_fn,N,Ci,H,W,Co,R,S,stride,padding", _WGRAD_TEST_CASES)
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell GPU (SM 10.x)")
def test_op(wgrad_fn, N, Ci, H, W, Co, R, S, stride, padding):
    """Validate each wgrad configuration against the PyTorch reference."""
    _assert_wgrad_correct(wgrad_fn, N, Ci, H, W, Co, R, S, stride, padding)


# ===-----------------------------------------------------------------------===#
# Benchmarking
# ===-----------------------------------------------------------------------===#

# Benchmark sweep parameters; one Benchmark config is built per entry in the
# cross-product of these lists.
BATCH = [128]  # N: batch size
CHANNELS = [(384, 384)]  # (Ci, Co): input / output channels
SPATIAL = [(64, 64)]  # (H, W): input spatial size
FILTER = [(3, 3)]  # (R, S): filter spatial size
STRIDE = [1]  # symmetric convolution stride
PADDING = [1]  # symmetric convolution padding


def _make_bench_inputs(N, H, W, Ci, Co, R, S, stride_val, pad_val):
    """Create seeded NCHW/NHWC activation and grad-output tensors for benchmarking.

    Returns ``(x_nchw, x_nhwc, grad_out_nchw, grad_out_nhwc)`` — NCHW copies
    feed the PyTorch baseline, contiguous NHWC copies feed the Gluon kernel.
    """
    torch.manual_seed(0)
    oh = (H + 2 * pad_val - R) // stride_val + 1
    ow = (W + 2 * pad_val - S) // stride_val + 1

    def _nchw_and_nhwc(shape):
        # Draw in NCHW first so the RNG stream matches the original layout order.
        nchw = torch.randn(shape, device="cuda", dtype=TORCH_GEMM_DTYPE)
        return nchw, nchw.permute(0, 2, 3, 1).contiguous()

    x_nchw, x_nhwc = _nchw_and_nhwc((N, Ci, H, W))
    grad_out_nchw, grad_out_nhwc = _nchw_and_nhwc((N, Co, oh, ow))
    return x_nchw, x_nhwc, grad_out_nchw, grad_out_nhwc


def _benchmark_tflops(fn, *, N, H, W, Ci, Co, R, S, stride_val, pad_val):
    """Time ``fn`` with Triton's benchmark harness and report achieved TFLOPS."""
    ms = triton.testing.do_bench(fn)
    oh = (H + 2 * pad_val - R) // stride_val + 1
    ow = (W + 2 * pad_val - S) // stride_val + 1
    # Each output-gradient element contributes one multiply-accumulate (2 flops)
    # per weight element it touches.
    total_flops = 2.0 * N * oh * ow * Co * Ci * R * S
    # TFLOPS = flops * 1e-12 divided by elapsed seconds (ms * 1e-3).
    return total_flops * 1e-12 / (ms * 1e-3)


# Build one Benchmark config per combination of the sweep lists above.
bench_configs = []
for N in BATCH:
    for Ci, Co in CHANNELS:
        for H, W in SPATIAL:
            for R, S in FILTER:
                for stride_val in STRIDE:
                    for pad_val in PADDING:
                        bench_configs.append(
                            triton.testing.Benchmark(
                                x_names=["kernel"],
                                x_vals=["autotuned"],
                                line_arg="provider",
                                line_vals=["gluon", "torch"],
                                line_names=["Gluon (autotuned)", "PyTorch"],
                                styles=[("green", "-"), ("blue", "-")],
                                ylabel="TFLOPS",
                                plot_name=f"Wgrad N={N} Ci={Ci} Co={Co} H={H} W={W} R={R} S={S} stride={stride_val} pad={pad_val}",
                                args={
                                    "N": N,
                                    "H": H,
                                    "W": W,
                                    "Ci": Ci,
                                    "Co": Co,
                                    "R": R,
                                    "S": S,
                                    "stride_val": stride_val,
                                    "pad_val": pad_val,
                                },
                            ))


@triton.testing.perf_report(bench_configs)
def bench(N, H, W, Ci, Co, R, S, stride_val, pad_val, kernel, provider):
    """Benchmark one wgrad configuration for the given provider; returns TFLOPS."""
    x_nchw, x_nhwc, grad_out_nchw, grad_out_nhwc = _make_bench_inputs(
        N, H, W, Ci, Co, R, S, stride_val, pad_val)

    if provider == "gluon":

        def fn():
            return conv2d_wgrad(x_nhwc, grad_out_nhwc, R, S, stride=stride_val, padding=pad_val)

    elif provider == "torch":
        # Random weights are fine here: output_mask asks ATen for only the
        # weight gradient, and wgrad depends on activations and grad_output.
        w_nchw = torch.randn((Co, Ci, R, S), device="cuda", dtype=TORCH_GEMM_DTYPE)

        def fn():
            return torch.ops.aten.convolution_backward(
                grad_out_nchw,
                x_nchw,
                w_nchw,
                bias_sizes=None,
                stride=[stride_val, stride_val],
                padding=[pad_val, pad_val],
                dilation=[1, 1],
                transposed=False,
                output_padding=[0, 0],
                groups=1,
                output_mask=[False, True, False],
            )

    else:
        raise ValueError(f"Unsupported provider: {provider}")

    return _benchmark_tflops(
        fn, N=N, H=H, W=W, Ci=Ci, Co=Co, R=R, S=S,
        stride_val=stride_val, pad_val=pad_val)


if __name__ == "__main__":
    # Run the full benchmark sweep; saves plots/CSV to the current directory
    # and prints the results table.
    bench.run(save_path=".", print_data=True)