Conv Fprop

This example can be found at python/examples/gluon/02-conv-fprop.py.

import importlib.util
import sys
from pathlib import Path

import pytest
import torch

import triton
import triton.language as tl

from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor, TensorDescriptorIm2Col
from triton.experimental.gluon.language.nvidia.hopper import tma, mbarrier
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    allocate_tensor_memory,
    tensor_memory_descriptor,
    tcgen05_mma,
    tcgen05_commit,
)


def _load_conv_common():
    """Load the sibling ``02-conv-common.py`` file as a module and return it.

    The helper file name starts with a digit and contains dashes, so it cannot
    be imported with a regular ``import`` statement; it is loaded by path and
    registered in ``sys.modules`` under a fixed synthetic name so that repeated
    calls share one module object.

    Returns:
        The loaded (or previously cached) helper module.

    Raises:
        ImportError: If no import spec/loader can be created for the file.
        Exception: Whatever executing the helper file raises. In that case the
            partially initialised module is removed from ``sys.modules`` so a
            later call retries instead of returning a broken module.
    """
    module_name = "triton_examples_gluon_conv_common"
    cached = sys.modules.get(module_name)
    if cached is not None:
        return cached

    module_path = Path(__file__).with_name("02-conv-common.py")
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Unable to load shared conv helpers from {module_path}")

    module = importlib.util.module_from_spec(spec)
    # Register before executing so code inside the helper can look itself up
    # by name while it is being initialised.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind: the cache lookup
        # above would otherwise hand it out on the next call.
        sys.modules.pop(module_name, None)
        raise
    return module


# Load the shared helper module once at import time; every name below is
# re-exported from it.
_conv_common = _load_conv_common()

# ===-----------------------------------------------------------------------===#
# Utilities
# ===-----------------------------------------------------------------------===#

# Local aliases for the shared helpers so the rest of this file can use short,
# unqualified names.
Counter = _conv_common.Counter
GL_GEMM_DTYPE = _conv_common.GL_GEMM_DTYPE
PersistentTileScheduler = _conv_common.PersistentTileScheduler
TORCH_GEMM_DTYPE = _conv_common.TORCH_GEMM_DTYPE
init_mbarrier_ring = _conv_common.init_mbarrier_ring
invalidate_mbarrier_ring = _conv_common.invalidate_mbarrier_ring
is_blackwell = _conv_common.is_blackwell
is_cuda = _conv_common.is_cuda
# Note the rename: the shared helper is called `maybe_pad_channel_dims_for_tma`.
maybe_pad_ci_for_tma = _conv_common.maybe_pad_channel_dims_for_tma
normalize_2d = _conv_common.normalize_2d

# ===-----------------------------------------------------------------------===#
# Convolution Configuration
# ===-----------------------------------------------------------------------===#

# Convolution parameter naming convention:
#   N   = batch size
#   H,W = input spatial dims
#   Ci  = input channels   (part of GEMM-K reduction: K_GEMM = R * S * Ci)
#   Co  = output channels  (maps to GEMM-N dimension: N_GEMM = Co)
#   R,S = filter height, width
#
# GEMM mapping:
#   M_GEMM = N * out_h * out_w   (output spatial positions)
#   N_GEMM = Co                  (output channels)
#   K_GEMM = R * S * Ci          (reduction over filter x input channels)


@gluon.aggregate
class ConvConfig:
    """Convolution problem description plus the compile-time tile configuration.

    The ``gl.tensor`` fields are runtime values describing the convolution and
    the output tensor strides; the ``gl.constexpr`` fields are the tiling
    parameters. The field order is the positional constructor order used by
    the kernel entry point, so it must not be rearranged.
    """
    # Problem dimensions: batch, input height/width, input/output channels,
    # filter height/width.
    N: gl.tensor
    H: gl.tensor
    W: gl.tensor
    Ci: gl.tensor
    Co: gl.tensor
    R: gl.tensor
    S: gl.tensor
    # Output spatial size and convolution stride/padding per dimension.
    out_h: gl.tensor
    out_w: gl.tensor
    stride_h: gl.tensor
    stride_w: gl.tensor
    pad_h: gl.tensor
    pad_w: gl.tensor
    # Element strides of the output tensor for the batch, row and column
    # dimensions (the channel dimension is addressed with an implicit stride
    # of 1 in the epilogue).
    output_stride_n: gl.tensor
    output_stride_h: gl.tensor
    output_stride_w: gl.tensor
    # GEMM-M extent: N * out_h * out_w, computed by the kernel entry point.
    M_GEMM: gl.tensor

    # Compile-time tile shape, tile-ordering group size and pipeline settings.
    BLOCK_M: gl.constexpr
    BLOCK_N: gl.constexpr
    BLOCK_K: gl.constexpr
    GROUP_SIZE_M: gl.constexpr
    num_buffers: gl.constexpr
    num_warps: gl.constexpr

    @gluon.jit
    def get_program(self, pid):
        """Compute tile coordinates from program ID with grouped ordering.

        Tile ids are walked in groups of ``GROUP_SIZE_M`` M-tiles by all
        N-tiles. Returns a ``ConvProgram`` holding the resulting
        ``(pid_m, pid_n)`` tile coordinates.
        """
        M_GEMM = self.M_GEMM
        N_GEMM = self.Co

        num_pid_m = gl.cdiv(M_GEMM, self.BLOCK_M)
        num_pid_n = gl.cdiv(N_GEMM, self.BLOCK_N)
        # Number of tiles in one full group: GROUP_SIZE_M rows of N-tiles.
        num_pid_in_group = self.GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * self.GROUP_SIZE_M
        # The last group may contain fewer than GROUP_SIZE_M rows.
        group_size_m = gl.minimum(num_pid_m - first_pid_m, self.GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m

        return ConvProgram(self, pid_m, pid_n)

    @gluon.jit
    def get_num_tiles(self):
        """Return the total number of output tiles (M-tiles times N-tiles)."""
        return gl.cdiv(self.M_GEMM, self.BLOCK_M) * gl.cdiv(self.Co, self.BLOCK_N)

    @gluon.jit
    def get_num_k_iterations(self):
        """Return the reduction step count: one per filter tap per BLOCK_K-wide channel chunk."""
        return self.R * self.S * gl.cdiv(self.Ci, self.BLOCK_K)


@gluon.aggregate
class ConvProgram:
    """Tile coordinates of one output tile together with the owning config."""
    config: ConvConfig
    pid_m: gl.tensor
    pid_n: gl.tensor

    @gluon.jit
    def get_m_offsets(self):
        """Split the first GEMM-M index of this tile into (batch, out_y, out_x)."""
        cfg = self.config
        # Linear output-pixel index of the first row of the tile.
        m_start = self.pid_m * cfg.BLOCK_M
        # Index of the output row, counted across all batches.
        row = m_start // cfg.out_w
        return row // cfg.out_h, row % cfg.out_h, m_start % cfg.out_w


# ===-----------------------------------------------------------------------===#
# Partition Arguments
# ===-----------------------------------------------------------------------===#


@gluon.aggregate
class PartitionArgs:
    """Everything the load, MMA and epilogue partitions share, passed as one aggregate.

    The field order is the positional constructor order used by the kernel
    entry point.
    """
    config: ConvConfig
    # TMA descriptors: im2col view of the input and tiled view of the weights.
    in_desc: tma.tensor_descriptor_im2col
    weight_desc: tma.tensor_descriptor
    # Base pointer of the output tensor written by the epilogue.
    output_ptr: gl.tensor
    # Shared-memory rings of input (A) and weight (B) tiles; the kernel
    # allocates them with `num_buffers` slots each.
    a_bufs: gl.shared_memory_descriptor
    b_bufs: gl.shared_memory_descriptor
    # Per-slot barriers pairing the load partition with the MMA partition.
    load_empty_bars: gl.shared_memory_descriptor
    load_ready_bars: gl.shared_memory_descriptor
    # Tensor-memory accumulator ring and the barriers pairing the MMA
    # partition with the epilogue partition.
    acc_bufs: tensor_memory_descriptor
    acc_empty_bars: gl.shared_memory_descriptor
    acc_ready_bars: gl.shared_memory_descriptor


# ===-----------------------------------------------------------------------===#
# Warp-Specialized Partitions
# ===-----------------------------------------------------------------------===#


@gluon.jit
def load_partition(p):
    """Load partition: iterate over this CTA's assigned output tiles.

    For every reduction step of every tile, waits for a free ring slot and
    issues two TMA loads into it: an im2col load of the input tile and a tiled
    load of the matching weight tile. Both loads signal the slot's "ready"
    barrier.
    """
    config = p.config
    BLOCK_K: gl.constexpr = config.BLOCK_K

    empty_bars = p.load_empty_bars
    ready_bars = p.load_ready_bars
    # NOTE(review): the producer counter is created with 1 where the consumer
    # in mma_partition uses 0 — presumably the initial phase, so that the
    # first wait on each "empty" barrier passes; confirm against Counter.
    state = Counter.create(1, empty_bars.shape[0])

    num_rs = config.R * config.S
    num_k_iter = config.get_num_k_iterations()

    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())
    for idx in range(scheduler.get_num_tiles()):
        prog = config.get_program(scheduler.get_tile_id(idx))
        batch_id, out_y, out_x = prog.get_m_offsets()

        for k_iter in range(num_k_iter):
            # Decompose the flat reduction index: the channel chunk is the
            # slow index, the filter tap (r, s) the fast one.
            iter_ci = k_iter // num_rs
            remain_rs = k_iter % num_rs
            iter_s = remain_rs % config.S
            iter_r = remain_rs // config.S

            ready_bar = ready_bars.index(state.index)
            # Wait until the slot is free, then announce how many bytes the
            # two loads below will deliver to the "ready" barrier.
            mbarrier.wait(empty_bars.index(state.index), state.phase)
            mbarrier.expect(ready_bar, p.in_desc.block_type.nbytes + p.weight_desc.block_type.nbytes)

            # Input tile: coordinates are (batch, top-left input row/column of
            # the first output pixel, channel offset); the filter tap is
            # passed separately as int16 offsets.
            tma.async_load_im2col(
                p.in_desc,
                [
                    batch_id,
                    out_y * config.stride_h - config.pad_h,
                    out_x * config.stride_w - config.pad_w,
                    iter_ci * BLOCK_K,
                ],
                [iter_r.to(tl.int16), iter_s.to(tl.int16)],
                ready_bar,
                p.a_bufs.index(state.index),
            )

            # Weight tile: the weight descriptor views the weights as
            # (Co, R*S*Ci), so the column of tap (r, s) and this channel chunk
            # is (r * S + s) * Ci + channel offset.
            k_offset = (iter_r * config.S + iter_s) * config.Ci + iter_ci * BLOCK_K
            tma.async_load(
                p.weight_desc,
                [prog.pid_n * config.BLOCK_N, k_offset],
                ready_bar,
                p.b_bufs.index(state.index),
            )
            state = state.next()


@gluon.jit
def mma_partition(p):
    """MMA partition: accumulate over all tiles assigned to this CTA.

    For each tile, waits for a free accumulator buffer, then consumes one
    loaded (A, B) slot per reduction step and issues a ``tcgen05_mma`` into
    that accumulator. Commits hand each input slot back to the load partition
    and, after the last step, hand the accumulator to the epilogue.
    """
    config = p.config

    num_k_iter = config.get_num_k_iterations()
    # Consumer of the load ring, producer of the accumulator ring.
    # NOTE(review): the differing first arguments (0 vs 1) presumably select
    # the initial barrier phase for consumer vs producer — confirm in Counter.
    load_state = Counter.create(0, p.load_empty_bars.shape[0])
    acc_state = Counter.create(1, p.acc_empty_bars.shape[0])

    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())
    for _ in range(scheduler.get_num_tiles()):
        mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase)
        acc_buf = p.acc_bufs.index(acc_state.index)
        # The first MMA of a tile overwrites the accumulator; later ones add to it.
        use_acc = False

        for _k_iter in range(num_k_iter):
            mbarrier.wait(p.load_ready_bars.index(load_state.index), load_state.phase)
            # The B tile is stored as [BLOCK_N, BLOCK_K]; permute it to K x N.
            tcgen05_mma(
                p.a_bufs.index(load_state.index),
                p.b_bufs.index(load_state.index).permute((1, 0)),
                acc_buf,
                use_acc=use_acc,
            )
            # Signal the slot's "empty" barrier via the commit so the load
            # partition can reuse the slot.
            tcgen05_commit(p.load_empty_bars.index(load_state.index))
            load_state = load_state.next()
            use_acc = True

        # Signal the epilogue that this accumulator buffer is ready.
        tcgen05_commit(p.acc_ready_bars.index(acc_state.index))
        acc_state = acc_state.next()


@gluon.jit
def epilogue_partition(p):
    """Epilogue partition: store all tiles assigned to this CTA.

    For each tile, waits for the finished accumulator, converts it to the
    GEMM dtype, releases the accumulator buffer and writes the tile to the
    output tensor with a bounds mask. After the last tile it invalidates all
    four barrier rings.
    """
    config = p.config
    BLOCK_M: gl.constexpr = config.BLOCK_M
    BLOCK_N: gl.constexpr = config.BLOCK_N
    M_GEMM = config.M_GEMM
    N_GEMM = config.Co

    acc_state = Counter.create(0, p.acc_empty_bars.shape[0])
    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())
    for idx in range(scheduler.get_num_tiles()):
        prog = config.get_program(scheduler.get_tile_id(idx))

        mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
        acc = p.acc_bufs.index(acc_state.index).load()
        result = gl.convert_layout(acc.to(GL_GEMM_DTYPE), gl.CoalescedLayout())
        # The accumulator has been read into registers, so its buffer can be
        # handed back to the MMA partition before the global store.
        mbarrier.arrive(p.acc_empty_bars.index(acc_state.index), count=1)
        acc_state = acc_state.next()

        offs_m = prog.pid_m * BLOCK_M + gl.arange(0, BLOCK_M)
        offs_n = prog.pid_n * BLOCK_N + gl.arange(0, BLOCK_N)

        # Split each GEMM-M index into (batch, out_y, out_x).
        c_out_x = offs_m % config.out_w
        c_out_y = (offs_m // config.out_w) % config.out_h
        c_batch = (offs_m // config.out_w) // config.out_h

        # The output-channel index is added without a stride, i.e. the
        # channel dimension is addressed as contiguous.
        c_offsets = (c_batch[:, None] * config.output_stride_n + c_out_y[:, None] * config.output_stride_h +
                     c_out_x[:, None] * config.output_stride_w + offs_n[None, :])
        # Mask rows/columns of edge tiles that fall outside the GEMM extents.
        c_mask = (offs_m[:, None] < M_GEMM) & (offs_n[None, :] < N_GEMM)
        gl.store(p.output_ptr + c_offsets, result, mask=c_mask)

    invalidate_mbarrier_ring(p.load_empty_bars)
    invalidate_mbarrier_ring(p.load_ready_bars)
    invalidate_mbarrier_ring(p.acc_empty_bars)
    invalidate_mbarrier_ring(p.acc_ready_bars)


# ===-----------------------------------------------------------------------===#
# Kernel Entry Point
# ===-----------------------------------------------------------------------===#


@gluon.jit(do_not_specialize=[
    "N",
    "H",
    "W",
    "R",
    "S",
    "pad_h",
    "pad_w",
])
def conv2d_fprop_kernel(
    in_desc,
    weight_desc,
    output,
    N,
    H,
    W,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    output_stride_n,
    output_stride_h,
    output_stride_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    BLOCK_M: gl.constexpr,
    BLOCK_N: gl.constexpr,
    BLOCK_K: gl.constexpr,
    GROUP_SIZE_M: gl.constexpr,
    num_buffers: gl.constexpr,
    num_acc_buffers: gl.constexpr,
    num_warps: gl.constexpr,
):
    """Warp-specialized forward convolution kernel.

    Builds the ``ConvConfig``, allocates the shared-memory tile rings, the
    tensor-memory accumulator ring and their barrier rings, then runs the
    epilogue, MMA and load partitions via ``gl.warp_specialize``.

    Args:
        in_desc: im2col TMA descriptor of the input tensor.
        weight_desc: tiled TMA descriptor of the weights viewed as (Co, R*S*Ci).
        output: pointer to the output tensor.
        N, H, W, Ci, Co, R, S: convolution dimensions.
        out_h, out_w: output spatial size.
        output_stride_n, output_stride_h, output_stride_w: output strides.
        stride_h, stride_w, pad_h, pad_w: convolution stride and padding.
        BLOCK_M, BLOCK_N, BLOCK_K: tile shape.
        GROUP_SIZE_M: group size used for tile ordering.
        num_buffers: number of slots in the A/B shared-memory rings.
        num_acc_buffers: number of accumulator buffers in tensor memory.
        num_warps: warp count, stored in the config.
    """
    M_GEMM = N * out_h * out_w
    # NOTE(review): out_h/out_w, the strides and the output strides are
    # wrapped with gl.to_tensor while Ci and Co are passed as-is — presumably
    # the wrapped ones may be specialized to compile-time constants; confirm
    # that Ci/Co can never hit that case.
    config = ConvConfig(
        N,
        H,
        W,
        Ci,
        Co,
        R,
        S,
        gl.to_tensor(out_h),
        gl.to_tensor(out_w),
        gl.to_tensor(stride_h),
        gl.to_tensor(stride_w),
        pad_h,
        pad_w,
        gl.to_tensor(output_stride_n),
        gl.to_tensor(output_stride_h),
        gl.to_tensor(output_stride_w),
        M_GEMM,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        GROUP_SIZE_M,
        num_buffers,
        num_warps,
    )

    # Shared-memory layouts matching the TMA block shapes of A and B.
    a_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], GL_GEMM_DTYPE)
    b_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK_N, BLOCK_K], GL_GEMM_DTYPE)

    # Rings of `num_buffers` input and weight tiles.
    a_bufs = gl.allocate_shared_memory(GL_GEMM_DTYPE, [num_buffers, BLOCK_M, BLOCK_K], a_smem_layout)
    b_bufs = gl.allocate_shared_memory(GL_GEMM_DTYPE, [num_buffers, BLOCK_N, BLOCK_K], b_smem_layout)

    # One "empty" and one "ready" barrier per ring slot.
    load_empty_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    load_ready_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    init_mbarrier_ring(load_empty_bars)
    init_mbarrier_ring(load_ready_bars)

    # The tensor-memory layout block uses 64 rows only for BLOCK_M == 64 and
    # 128 rows for every other BLOCK_M.
    TMEM_BLOCK_M: gl.constexpr = 64 if BLOCK_M == 64 else 128
    tmem_layout: gl.constexpr = TensorMemoryLayout(block=(TMEM_BLOCK_M, BLOCK_N), col_stride=1)
    # Smaller tiles can profit from a double-buffered accumulator ring, but
    # large 256x256 tiles exceed Blackwell's TMEM budget unless the ring depth
    # is reduced to 1.
    acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout)

    # One "empty" and one "ready" barrier per accumulator buffer.
    acc_empty_bars = gl.allocate_shared_memory(gl.int64, [num_acc_buffers, 1], mbarrier.MBarrierLayout())
    acc_ready_bars = gl.allocate_shared_memory(gl.int64, [num_acc_buffers, 1], mbarrier.MBarrierLayout())
    init_mbarrier_ring(acc_empty_bars)
    init_mbarrier_ring(acc_ready_bars)

    p = PartitionArgs(
        config,
        in_desc,
        weight_desc,
        output,
        a_bufs,
        b_bufs,
        load_empty_bars,
        load_ready_bars,
        acc_bufs,
        acc_empty_bars,
        acc_ready_bars,
    )

    # NOTE(review): the two trailing lists presumably give the warp count and
    # register budget of the two worker partitions (MMA, load), with the
    # epilogue running as the default partition — confirm against the
    # gl.warp_specialize signature.
    gl.warp_specialize([
        (epilogue_partition, (p, )),
        (mma_partition, (p, )),
        (load_partition, (p, )),
    ], [1, 1], [24, 24])


def conv2d_fprop_get_configs(pre_hook=None):
    """Return the autotuning search space as a list of ``triton.Config``.

    The sweep covers BLOCK_M in (64, 128), BLOCK_N in (8, 32, 128, 256) and
    num_buffers in (3, 4, 5); BLOCK_K, GROUP_SIZE_M, num_acc_buffers and
    num_warps are held fixed.

    Args:
        pre_hook: Optional callable forwarded to every config as its pre-hook.
    """
    search_space = []
    for tile_m in (64, 128):
        for tile_n in (8, 32, 128, 256):
            for ring_depth in (3, 4, 5):
                meta = {
                    "BLOCK_M": tile_m,
                    "BLOCK_N": tile_n,
                    "BLOCK_K": 64,
                    "GROUP_SIZE_M": 4,
                    "num_buffers": ring_depth,
                    "num_acc_buffers": 2,
                }
                search_space.append(triton.Config(meta, num_warps=4, pre_hook=pre_hook))
    return search_space


def conv2d_fprop_tma_set_block_size_hook(nargs):
    """Autotune pre-hook: make both TMA descriptors match the config's tile shape.

    Sets ``block_shape`` and ``layout`` on ``nargs["in_desc"]`` to
    [BLOCK_M, BLOCK_K] and on ``nargs["weight_desc"]`` to [BLOCK_N, BLOCK_K].
    """
    block_k = nargs["BLOCK_K"]
    tile_rows = (("in_desc", nargs["BLOCK_M"]), ("weight_desc", nargs["BLOCK_N"]))
    for desc_name, rows in tile_rows:
        block_shape = [rows, block_k]
        desc = nargs[desc_name]
        desc.block_shape = block_shape
        desc.layout = gl.NVMMASharedLayout.get_default_for(block_shape, GL_GEMM_DTYPE)


# Key on the effective implicit-GEMM/convolution geometry instead of the full
# raw input shape:
#   * N, out_h, out_w  -> M_GEMM = N * out_h * out_w
#   * Co               -> N_GEMM
#   * Ci, R, S         -> K_GEMM = R * S * Ci
#   * stride_h/w       -> input access pattern
# `out_h/out_w` already encode the impact of H/W/padding on the launch shape,
# so adding H, W, pad_h and pad_w as well would only fragment the autotune
# cache without exposing meaningfully different tile choices. The GEMM extents
# must be part of the key, though: otherwise a config tuned for one problem
# size is silently reused for a differently sized problem that merely shares
# the same output resolution.
conv2d_fprop_autotuned_kernel = triton.autotune(
    configs=conv2d_fprop_get_configs(pre_hook=conv2d_fprop_tma_set_block_size_hook),
    key=["N", "Ci", "Co", "R", "S", "out_h", "out_w", "stride_h", "stride_w"],
)(conv2d_fprop_kernel)

# ===-----------------------------------------------------------------------===#
# Host-Side Entry Point
# ===-----------------------------------------------------------------------===#


def _prepare_conv_fprop_inputs(input_tensor, weight_tensor, stride, padding):
    """Validate the convolution arguments and allocate the output tensor.

    Args:
        input_tensor: 4-D input in (N, H, W, Ci) order.
        weight_tensor: 4-D weights in (Co, R, S, Ci) order.
        stride: Convolution stride, normalized to (stride_h, stride_w) by
            ``normalize_2d``.
        padding: Convolution padding, normalized to (pad_h, pad_w) by
            ``normalize_2d``.

    Returns:
        The tuple ``(input_tensor, weight_tensor, output, N, H, W, Ci, Co, R,
        S, out_h, out_w, stride_h, stride_w, pad_h, pad_w)``. The tensors and
        the dimensions are the ones returned by the channel-padding helper,
        and ``output`` is an uninitialised (N, out_h, out_w, Co) tensor on the
        input's device.

    Raises:
        ValueError: If the tensors are not 4-D, their channel counts differ,
            a dtype is not the GEMM dtype, the stride is not positive, the
            padding is negative, or the computed output size is not positive.
    """
    if len(input_tensor.shape) != 4 or len(weight_tensor.shape) != 4:
        raise ValueError("conv2d_fprop expects a 4-D NHWC input and a 4-D (Co, R, S, Ci) weight, got shapes "
                         f"{tuple(input_tensor.shape)} and {tuple(weight_tensor.shape)}")
    N, H, W, Ci = input_tensor.shape
    Co, R, S, Ci_w = weight_tensor.shape
    # Raise instead of assert: asserts are stripped under `python -O`, and the
    # other argument checks in this function already report via ValueError.
    if Ci != Ci_w:
        raise ValueError(f"Input and weight channel dimensions must match, got {Ci} and {Ci_w}")
    if input_tensor.dtype != TORCH_GEMM_DTYPE or weight_tensor.dtype != TORCH_GEMM_DTYPE:
        raise ValueError(
            f"conv2d_fprop expects bfloat16 input/weight tensors, got {input_tensor.dtype} and {weight_tensor.dtype}")
    stride_h, stride_w = normalize_2d(stride, "stride")
    pad_h, pad_w = normalize_2d(padding, "padding")
    if stride_h <= 0 or stride_w <= 0:
        raise ValueError(f"stride must be positive, got {(stride_h, stride_w)}")
    if pad_h < 0 or pad_w < 0:
        raise ValueError(f"padding must be non-negative, got {(pad_h, pad_w)}")

    # The Hopper/Blackwell TMA path requires outer strides to be 16-byte aligned.
    # For NHWC/OHWI bf16 tensors this means padding the channel dimension for
    # narrow inputs such as RGB (Ci=3).
    input_tensor, weight_tensor = maybe_pad_ci_for_tma(input_tensor, weight_tensor)
    # Re-read the dimensions: the helper may have returned padded tensors.
    N, H, W, Ci = input_tensor.shape
    Co, R, S, _ = weight_tensor.shape

    out_h = (H + 2 * pad_h - R) // stride_h + 1
    out_w = (W + 2 * pad_w - S) // stride_w + 1
    if out_h <= 0 or out_w <= 0:
        raise ValueError("Invalid convolution geometry: computed output size "
                         f"({out_h}, {out_w}) from H={H}, W={W}, R={R}, S={S}, "
                         f"stride={(stride_h, stride_w)}, padding={(pad_h, pad_w)}.")

    output = torch.empty((N, out_h, out_w, Co), device=input_tensor.device, dtype=TORCH_GEMM_DTYPE)
    return input_tensor, weight_tensor, output, N, H, W, Ci, Co, R, S, out_h, out_w, stride_h, stride_w, pad_h, pad_w


def _make_conv_fprop_descriptors(input_tensor, weight_tensor, out_h, out_w, stride_h, stride_w, pad_h, pad_w,
                                 input_block_shape, weight_block_shape):
    """Create the im2col input descriptor and the tiled weight descriptor.

    Args:
        input_tensor: (N, H, W, Ci) input tensor.
        weight_tensor: (Co, R, S, Ci) weight tensor.
        out_h, out_w: Output spatial size.
        stride_h, stride_w: Convolution stride per dimension.
        pad_h, pad_w: Convolution padding per dimension.
        input_block_shape: Block shape for the input descriptor.
        weight_block_shape: Block shape for the weight descriptor.

    Returns:
        ``(in_desc, weight_desc)``.
    """
    # --- Input: im2col descriptor over the NHWC tensor -----------------------
    #
    # Per batch, the pixel box bounds the accessed region:
    #   lower = pixel_box_lower_corner + offsets
    #   upper = [H, W] + pixel_box_upper_corner + offsets
    # element_strides = [1, stride_h, stride_w, 1] makes TMA advance by the
    # convolution stride from one output pixel to the next, so a window of
    # `window` input pixels yields floor((window - 1) / stride) + 1 output
    # pixels. For that count to equal out_h (resp. out_w) the window must be
    #   window = (out - 1) * stride + 1
    # and, with the lower corner at -pad, the upper corner follows as
    #   upper_corner = window - size - pad.
    _, in_h, in_w, _ = input_tensor.shape
    window_h = (out_h - 1) * stride_h + 1
    window_w = (out_w - 1) * stride_w + 1
    box_lower = [-pad_h, -pad_w]
    box_upper = [window_h - in_h - pad_h, window_w - in_w - pad_w]

    in_desc = TensorDescriptorIm2Col(
        base=input_tensor,
        shape=list(input_tensor.shape),
        strides=list(input_tensor.stride()),
        block_shape=input_block_shape,
        layout=gl.NVMMASharedLayout.get_default_for(input_block_shape, GL_GEMM_DTYPE),
        padding="zero",
        element_strides=[1, stride_h, stride_w, 1],
        pixel_box_lower_corner=box_lower,
        pixel_box_upper_corner=box_upper,
    )

    # --- Weight: tiled descriptor over the (N_GEMM, K_GEMM) matrix -----------
    n_gemm, filt_h, filt_w, chans = weight_tensor.shape
    weight_matrix = weight_tensor.reshape(n_gemm, filt_h * filt_w * chans)
    weight_desc = TensorDescriptor.from_tensor(
        weight_matrix,
        weight_block_shape,
        gl.NVMMASharedLayout.get_default_for(weight_block_shape, GL_GEMM_DTYPE),
    )
    return in_desc, weight_desc


def _make_grid(num_sms, M_GEMM, N_GEMM):
    """Return a grid callback that launches at most one program per SM.

    The callback receives the kernel meta-parameters and returns a 1-tuple
    holding ``min(num_sms, number of output tiles)``.
    """

    def grid(meta):
        tiles_m = triton.cdiv(M_GEMM, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N_GEMM, meta["BLOCK_N"])
        return (min(num_sms, tiles_m * tiles_n), )

    return grid


def _launch_conv(
    kernel,
    grid,
    *,
    in_desc,
    weight_desc,
    output,
    N,
    H,
    W,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    kernel_meta=None,
):
    """Launch ``kernel[grid]`` with the convolution arguments in kernel order.

    The first three strides of ``output`` are inserted between the output
    size and the convolution stride arguments. ``kernel_meta``, when given,
    is forwarded as extra keyword arguments.
    """
    meta = {} if kernel_meta is None else kernel_meta
    launcher = kernel[grid]

    problem_dims = (N, H, W, Ci, Co, R, S, out_h, out_w)
    output_strides = tuple(output.stride(dim) for dim in range(3))
    conv_params = (stride_h, stride_w, pad_h, pad_w)

    launcher(in_desc, weight_desc, output, *problem_dims, *output_strides, *conv_params, **meta)


def conv2d_fprop(input_tensor, weight_tensor, stride=1, padding=0, **kwargs):
    """Autotuned forward-convolution entry point.

    Validates the arguments, builds the TMA descriptors and launches the
    autotuned kernel, which picks the tile configuration for the given
    convolution shape.

    Args:
        input_tensor: (N, H, W, Ci) input tensor.
        weight_tensor: (Co, R, S, Ci) weight tensor.
        stride: Convolution stride.
        padding: Convolution padding.
        **kwargs: Accepted but not used by this function.

    Returns:
        The (N, out_h, out_w, Co) output tensor.
    """
    (input_tensor, weight_tensor, output, N, H, W, Ci, Co, R, S, out_h, out_w, stride_h, stride_w, pad_h,
     pad_w) = _prepare_conv_fprop_inputs(input_tensor, weight_tensor, stride, padding)

    sm_count = torch.cuda.get_device_properties(input_tensor.device).multi_processor_count

    # The real block shapes depend on the config being tried; the autotune
    # pre-hook overwrites them on the descriptors, so a placeholder suffices.
    placeholder_block = [1, 1]
    in_desc, weight_desc = _make_conv_fprop_descriptors(input_tensor, weight_tensor, out_h, out_w, stride_h,
                                                        stride_w, pad_h, pad_w, placeholder_block,
                                                        placeholder_block)

    # Implicit-GEMM extents: M = N * out_h * out_w, N = Co.
    grid = _make_grid(sm_count, N * out_h * out_w, Co)
    _launch_conv(conv2d_fprop_autotuned_kernel, grid, in_desc=in_desc, weight_desc=weight_desc, output=output,
                 N=N, H=H, W=W, Ci=Ci, Co=Co, R=R, S=S, out_h=out_h, out_w=out_w, stride_h=stride_h,
                 stride_w=stride_w, pad_h=pad_h, pad_w=pad_w)
    return output


# Second public name bound to the same function.
conv2d_fprop_persistent = conv2d_fprop


def _make_conv2d_fprop_fixed_kernel_meta(num_buffers, num_warps):
    """Return the kernel meta-parameters used by the fixed-config entry point.

    The tile shape (128x128x64, group size 4, two accumulator buffers) is one
    that the autotune configs also cover; only the ring depth and warp count
    come from the caller.
    """
    tile = dict(BLOCK_M=128, BLOCK_N=128, BLOCK_K=64, GROUP_SIZE_M=4)
    pipeline = dict(num_buffers=num_buffers, num_acc_buffers=2, num_warps=num_warps)
    return {**tile, **pipeline}


def conv2d_fprop_fixed(input_tensor, weight_tensor, stride=1, padding=0, num_buffers=3, num_warps=4):
    """Fixed-configuration forward convolution, used for CI and debugging.

    Same flow as the autotuned entry point, but the kernel is launched
    directly with one fixed tile shape, so no autotuning takes place.

    Args:
        input_tensor: (N, H, W, Ci) input tensor.
        weight_tensor: (Co, R, S, Ci) weight tensor.
        stride: Convolution stride.
        padding: Convolution padding.
        num_buffers: Number of slots in the A/B tile rings.
        num_warps: Warp count passed to the kernel.

    Returns:
        The (N, out_h, out_w, Co) output tensor.
    """
    (input_tensor, weight_tensor, output, N, H, W, Ci, Co, R, S, out_h, out_w, stride_h, stride_w, pad_h,
     pad_w) = _prepare_conv_fprop_inputs(input_tensor, weight_tensor, stride, padding)

    meta = _make_conv2d_fprop_fixed_kernel_meta(num_buffers, num_warps)
    tile_m, tile_n, tile_k = meta["BLOCK_M"], meta["BLOCK_N"], meta["BLOCK_K"]

    sm_count = torch.cuda.get_device_properties(input_tensor.device).multi_processor_count

    # Without an autotune pre-hook the descriptors need their final block
    # shapes up front.
    in_desc, weight_desc = _make_conv_fprop_descriptors(input_tensor, weight_tensor, out_h, out_w, stride_h,
                                                        stride_w, pad_h, pad_w, [tile_m, tile_k],
                                                        [tile_n, tile_k])

    # Implicit-GEMM extents: M = N * out_h * out_w, N = Co.
    grid = _make_grid(sm_count, N * out_h * out_w, Co)
    _launch_conv(conv2d_fprop_kernel, grid, in_desc=in_desc, weight_desc=weight_desc, output=output, N=N, H=H,
                 W=W, Ci=Ci, Co=Co, R=R, S=S, out_h=out_h, out_w=out_w, stride_h=stride_h, stride_w=stride_w,
                 pad_h=pad_h, pad_w=pad_w, kernel_meta=meta)
    return output


# ===-----------------------------------------------------------------------===#
# Unit Tests
# ===-----------------------------------------------------------------------===#


def _assert_conv_fprop_correct(fprop_fn, N, Ci, H, W, Co, R, S, stride, padding, **kwargs):
    """Check ``fprop_fn`` against ``torch.nn.functional.conv2d`` on random data.

    Inputs are generated in NCHW with a fixed seed, handed to ``fprop_fn`` as
    contiguous channels-last copies, and the reference result is permuted to
    channels-last before the comparison.
    """
    torch.manual_seed(0)
    # Draw the input before the weights so the random stream is consumed in a
    # fixed order.
    x_nchw = torch.randn((N, Ci, H, W), device="cuda", dtype=TORCH_GEMM_DTYPE)
    w_nchw = torch.randn((Co, Ci, R, S), device="cuda", dtype=TORCH_GEMM_DTYPE)
    x_nhwc, w_nhwc = (t.permute(0, 2, 3, 1).contiguous() for t in (x_nchw, w_nchw))

    actual = fprop_fn(x_nhwc, w_nhwc, stride=stride, padding=padding, **kwargs)
    expected = torch.nn.functional.conv2d(x_nchw, w_nchw, stride=stride, padding=padding).permute(0, 2, 3, 1)
    torch.testing.assert_close(actual, expected, atol=5e-2, rtol=5e-2)


# Dense sweep on a 64x64 input: batch size x channel count x filter size x
# stride x padding.
_TEST_OP_SWEEP_CASES = [
    (conv2d_fprop_fixed, N, Ci, 64, 64, Co, R, S, stride, padding)
    for N in (1, 128)
    for Ci, Co in ((384, 384), (416, 416))
    for R, S in ((3, 3), (4, 4), (5, 5))
    for stride in (1, 2)
    for padding in (0, 1)
]

# Hand-picked corner cases.
_TEST_OP_CORNER_CASES = [
    (conv2d_fprop_fixed, 1, 96, 1, 8, 128, 1, 2, (1, 2), 0),  # asymmetric stride
    (conv2d_fprop_fixed, 16, 5, 32, 32, 96, 3, 3, 1, 1),  # padded channels
]


@pytest.mark.parametrize("fprop_fn,N,Ci,H,W,Co,R,S,stride,padding", _TEST_OP_SWEEP_CASES + _TEST_OP_CORNER_CASES)
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell GPU (SM 10.x)")
def test_op(fprop_fn, N, Ci, H, W, Co, R, S, stride, padding):
    """Compare the fixed-config convolution against PyTorch for one parameter set."""
    _assert_conv_fprop_correct(fprop_fn, N, Ci, H, W, Co, R, S, stride, padding)


# ===-----------------------------------------------------------------------===#
# Benchmarking
# ===-----------------------------------------------------------------------===#

# Benchmark sweep axes. One benchmark is generated for every combination of
# the values below.
BATCH = [128]  # batch sizes N
CHANNELS = [(384, 384)]  # (Ci, Co) pairs
SPATIAL = [(64, 64)]  # (H, W) pairs
FILTER = [(3, 3)]  # (R, S) pairs
STRIDE = [1]  # convolution strides
PADDING = [1]  # convolution paddings


def _make_bench_inputs(N, H, W, Ci, Co, R, S):
    """Create seeded random benchmark tensors in both layouts.

    Returns:
        ``(x_nchw, x_nhwc, w_nchw, w_nhwc)``: the input and weights in NCHW
        order plus contiguous channels-last copies of each.
    """
    torch.manual_seed(0)

    # Draw the input before the weights so the random stream is consumed in a
    # fixed order.
    x_nchw = torch.randn((N, Ci, H, W), device="cuda", dtype=TORCH_GEMM_DTYPE)
    w_nchw = torch.randn((Co, Ci, R, S), device="cuda", dtype=TORCH_GEMM_DTYPE)
    channels_last = (0, 2, 3, 1)
    x_nhwc = x_nchw.permute(*channels_last).contiguous()
    w_nhwc = w_nchw.permute(*channels_last).contiguous()
    return x_nchw, x_nhwc, w_nchw, w_nhwc


def _benchmark_tflops(fn, *, N, H, W, Ci, Co, R, S, stride_val, pad_val):
    """Time ``fn`` and convert the result to TFLOPS for the given convolution.

    Args:
        fn: Zero-argument callable to benchmark.
        N, H, W, Ci, Co, R, S: Convolution dimensions.
        stride_val: Convolution stride, in any form ``normalize_2d`` accepts
            (the same helper the convolution entry points use).
        pad_val: Convolution padding, in any form ``normalize_2d`` accepts.

    Returns:
        Achieved TFLOPS, counting 2 * N * out_h * out_w * Co * Ci * R * S
        floating-point operations per call.
    """
    # Normalize first so per-dimension stride/padding is handled the same way
    # as in the convolution itself; the scalar-only arithmetic used before
    # raised TypeError for tuple arguments.
    stride_h, stride_w = normalize_2d(stride_val, "stride")
    pad_h, pad_w = normalize_2d(pad_val, "padding")

    ms = triton.testing.do_bench(fn)
    out_h = (H + 2 * pad_h - R) // stride_h + 1
    out_w = (W + 2 * pad_w - S) // stride_w + 1
    flops = 2.0 * N * out_h * out_w * Co * Ci * R * S
    # do_bench reports milliseconds; convert to seconds and to tera-FLOPs.
    return flops * 1e-12 / (ms * 1e-3)


# One Benchmark per combination of the sweep axes; each compares the autotuned
# Gluon kernel with PyTorch and reports TFLOPS.
bench_configs = [
    triton.testing.Benchmark(
        x_names=["kernel"],
        x_vals=["autotuned"],
        line_arg="provider",
        line_vals=["gluon", "torch"],
        line_names=["Gluon (autotuned)", "PyTorch"],
        styles=[("green", "-"), ("blue", "-")],
        ylabel="TFLOPS",
        plot_name=f"Conv2d N={N} Ci={Ci} Co={Co} H={H} W={W} R={R} S={S} stride={stride_val} pad={pad_val}",
        args={
            "N": N,
            "H": H,
            "W": W,
            "Ci": Ci,
            "Co": Co,
            "R": R,
            "S": S,
            "stride_val": stride_val,
            "pad_val": pad_val,
        },
    )
    for N in BATCH
    for (Ci, Co) in CHANNELS
    for (H, W) in SPATIAL
    for (R, S) in FILTER
    for stride_val in STRIDE
    for pad_val in PADDING
]


@triton.testing.perf_report(bench_configs)
def bench(N, H, W, Ci, Co, R, S, stride_val, pad_val, kernel, provider):
    """Benchmark one provider on one convolution shape and return its TFLOPS.

    ``provider`` selects the implementation: "gluon" runs the autotuned
    kernel on channels-last tensors, "torch" runs
    ``torch.nn.functional.conv2d`` on the NCHW tensors. ``kernel`` is the
    benchmark's x-axis value and is not used in the body.

    Raises:
        ValueError: If ``provider`` is neither "gluon" nor "torch".
    """
    x_nchw, x_nhwc, w_nchw, w_nhwc = _make_bench_inputs(N, H, W, Ci, Co, R, S)

    def run_gluon():
        return conv2d_fprop(x_nhwc, w_nhwc, stride=stride_val, padding=pad_val)

    def run_torch():
        return torch.nn.functional.conv2d(x_nchw, w_nchw, stride=stride_val, padding=pad_val)

    runners = {"gluon": run_gluon, "torch": run_torch}
    if provider not in runners:
        raise ValueError(f"Unsupported provider: {provider}")

    shape = dict(N=N, H=H, W=W, Ci=Ci, Co=Co, R=R, S=S, stride_val=stride_val, pad_val=pad_val)
    return _benchmark_tflops(runners[provider], **shape)


# When executed as a script, run the benchmark sweep, print the results and
# save the report files to the current directory.
if __name__ == "__main__":
    bench.run(save_path=".", print_data=True)