# Conv Fprop
# This example can be found at python/examples/gluon/02-conv-fprop.py.
import importlib.util
import sys
from pathlib import Path
import pytest
import torch
import triton
import triton.language as tl
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor, TensorDescriptorIm2Col
from triton.experimental.gluon.language.nvidia.hopper import tma, mbarrier
from triton.experimental.gluon.language.nvidia.blackwell import (
TensorMemoryLayout,
allocate_tensor_memory,
tensor_memory_descriptor,
tcgen05_mma,
tcgen05_commit,
)
def _load_conv_common():
    """Load and cache the shared conv helper module ``02-conv-common.py``.

    The helpers live in a sibling file whose name is not a valid Python
    identifier, so the module is imported via importlib under a synthetic
    module name and cached in ``sys.modules`` so repeated calls return the
    same module object.

    Returns:
        The executed helper module.

    Raises:
        ImportError: if a loader spec cannot be built for the helper file.
    """
    module_name = "triton_examples_gluon_conv_common"
    module = sys.modules.get(module_name)
    if module is not None:
        return module
    module_path = Path(__file__).with_name("02-conv-common.py")
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Unable to load shared conv helpers from {module_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialized module in the cache if execution
        # fails; a later call would otherwise silently return the broken
        # module instead of retrying (and re-raising) the import.
        sys.modules.pop(module_name, None)
        raise
    return module
# Load the shared helpers once at import time; the names below are short
# aliases into that module.
_conv_common = _load_conv_common()
# ===-----------------------------------------------------------------------===#
# Utilities
# ===-----------------------------------------------------------------------===#
Counter = _conv_common.Counter
GL_GEMM_DTYPE = _conv_common.GL_GEMM_DTYPE
PersistentTileScheduler = _conv_common.PersistentTileScheduler
TORCH_GEMM_DTYPE = _conv_common.TORCH_GEMM_DTYPE
init_mbarrier_ring = _conv_common.init_mbarrier_ring
invalidate_mbarrier_ring = _conv_common.invalidate_mbarrier_ring
is_blackwell = _conv_common.is_blackwell
is_cuda = _conv_common.is_cuda
# Shortened alias: pads the channel dims of *both* input and weight for TMA.
maybe_pad_ci_for_tma = _conv_common.maybe_pad_channel_dims_for_tma
normalize_2d = _conv_common.normalize_2d
# ===-----------------------------------------------------------------------===#
# Convolution Configuration
# ===-----------------------------------------------------------------------===#
# Convolution parameter naming convention:
# N = batch size
# H,W = input spatial dims
# Ci = input channels (part of GEMM-K reduction: K_GEMM = R * S * Ci)
# Co = output channels (maps to GEMM-N dimension: N_GEMM = Co)
# R,S = filter height, width
#
# GEMM mapping:
# M_GEMM = N * out_h * out_w (output spatial positions)
# N_GEMM = Co (output channels)
# K_GEMM = R * S * Ci (reduction over filter x input channels)
@gluon.aggregate
class ConvConfig:
    # Convolution problem geometry (runtime tensors) plus compile-time tiling
    # parameters (constexpr). See the naming-convention comment above for the
    # meaning of N/H/W/Ci/Co/R/S and the GEMM mapping.
    N: gl.tensor
    H: gl.tensor
    W: gl.tensor
    Ci: gl.tensor
    Co: gl.tensor
    R: gl.tensor
    S: gl.tensor
    out_h: gl.tensor
    out_w: gl.tensor
    stride_h: gl.tensor
    stride_w: gl.tensor
    pad_h: gl.tensor
    pad_w: gl.tensor
    # NHWC output strides (in elements) used by the epilogue's scatter store.
    output_stride_n: gl.tensor
    output_stride_h: gl.tensor
    output_stride_w: gl.tensor
    # M_GEMM = N * out_h * out_w, precomputed by the kernel entry point.
    M_GEMM: gl.tensor
    BLOCK_M: gl.constexpr
    BLOCK_N: gl.constexpr
    BLOCK_K: gl.constexpr
    GROUP_SIZE_M: gl.constexpr
    num_buffers: gl.constexpr
    num_warps: gl.constexpr

    @gluon.jit
    def get_program(self, pid):
        """Compute tile coordinates from program ID with grouped ordering."""
        # Grouped tile ordering (as in Triton's matmul tutorial): walk
        # GROUP_SIZE_M rows of M-tiles together to improve L2 reuse.
        M_GEMM = self.M_GEMM
        N_GEMM = self.Co
        num_pid_m = gl.cdiv(M_GEMM, self.BLOCK_M)
        num_pid_n = gl.cdiv(N_GEMM, self.BLOCK_N)
        num_pid_in_group = self.GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * self.GROUP_SIZE_M
        # The last group may be ragged when num_pid_m % GROUP_SIZE_M != 0.
        group_size_m = gl.minimum(num_pid_m - first_pid_m, self.GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        return ConvProgram(self, pid_m, pid_n)

    @gluon.jit
    def get_num_tiles(self):
        """Total number of (M, N) output tiles for this problem."""
        return gl.cdiv(self.M_GEMM, self.BLOCK_M) * gl.cdiv(self.Co, self.BLOCK_N)

    @gluon.jit
    def get_num_k_iterations(self):
        """Number of K pipeline steps: one per (r, s) tap per Ci chunk."""
        return self.R * self.S * gl.cdiv(self.Ci, self.BLOCK_K)
@gluon.aggregate
class ConvProgram:
    # One program's tile assignment: the (pid_m, pid_n) GEMM tile plus the
    # shared convolution configuration.
    config: ConvConfig
    pid_m: gl.tensor
    pid_n: gl.tensor

    @gluon.jit
    def get_m_offsets(self):
        """Decompose M-tile offset into (batch, out_y, out_x)."""
        # M_GEMM is flattened as N x out_h x out_w (row-major), so peel the
        # fastest-varying out_w first, then out_h, leaving the batch index.
        offs_m = self.pid_m * self.config.BLOCK_M
        config = self.config
        out_x = offs_m % config.out_w
        out_y = (offs_m // config.out_w) % config.out_h
        batch_id = (offs_m // config.out_w) // config.out_h
        return batch_id, out_y, out_x
# ===-----------------------------------------------------------------------===#
# Partition Arguments
# ===-----------------------------------------------------------------------===#
@gluon.aggregate
class PartitionArgs:
    # Shared state handed to every warp-specialized partition.
    config: ConvConfig
    # TMA descriptors: im2col view of the NHWC input, tiled view of the
    # weight matrix reshaped to (Co, R*S*Ci).
    in_desc: tma.tensor_descriptor_im2col
    weight_desc: tma.tensor_descriptor
    output_ptr: gl.tensor
    # Shared-memory ring buffers for the A (activation) and B (weight) tiles.
    a_bufs: gl.shared_memory_descriptor
    b_bufs: gl.shared_memory_descriptor
    # Producer/consumer mbarrier rings guarding the load buffers.
    load_empty_bars: gl.shared_memory_descriptor
    load_ready_bars: gl.shared_memory_descriptor
    # Tensor-memory accumulator ring plus its mbarrier pairs.
    acc_bufs: tensor_memory_descriptor
    acc_empty_bars: gl.shared_memory_descriptor
    acc_ready_bars: gl.shared_memory_descriptor
# ===-----------------------------------------------------------------------===#
# Warp-Specialized Partitions
# ===-----------------------------------------------------------------------===#
@gluon.jit
def load_partition(p):
    """Load partition: iterate over this CTA's assigned output tiles.

    Producer side of the load ring: for each tile and each K iteration it
    waits for a free buffer slot, then issues one TMA im2col load (input
    activations) and one tiled TMA load (weights) into that slot; both loads
    signal the slot's ready barrier on completion.
    """
    config = p.config
    BLOCK_K: gl.constexpr = config.BLOCK_K
    empty_bars = p.load_empty_bars
    ready_bars = p.load_ready_bars
    # NOTE(review): initial phase 1 here vs. 0 on the consumer side appears to
    # encode that all slots start "empty" -- confirm against the Counter helper.
    state = Counter.create(1, empty_bars.shape[0])
    num_rs = config.R * config.S
    num_k_iter = config.get_num_k_iterations()
    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())
    for idx in range(scheduler.get_num_tiles()):
        prog = config.get_program(scheduler.get_tile_id(idx))
        batch_id, out_y, out_x = prog.get_m_offsets()
        for k_iter in range(num_k_iter):
            # Decompose the flat K iteration: (r, s) vary fastest; the Ci
            # chunk advances once every R*S iterations.
            iter_ci = k_iter // num_rs
            remain_rs = k_iter % num_rs
            iter_s = remain_rs % config.S
            iter_r = remain_rs // config.S
            ready_bar = ready_bars.index(state.index)
            # Wait for the consumer to drain this slot, then arm the ready
            # barrier with the total byte count the two TMA loads deposit.
            mbarrier.wait(empty_bars.index(state.index), state.phase)
            mbarrier.expect(ready_bar, p.in_desc.block_type.nbytes + p.weight_desc.block_type.nbytes)
            # im2col load: base coords are the top-left input pixel of this
            # output tile (shifted by padding); offsets select the (r, s) tap.
            tma.async_load_im2col(
                p.in_desc,
                [
                    batch_id,
                    out_y * config.stride_h - config.pad_h,
                    out_x * config.stride_w - config.pad_w,
                    iter_ci * BLOCK_K,
                ],
                [iter_r.to(tl.int16), iter_s.to(tl.int16)],
                ready_bar,
                p.a_bufs.index(state.index),
            )
            # Weight K offset into the flattened (R, S, Ci) reduction axis.
            k_offset = (iter_r * config.S + iter_s) * config.Ci + iter_ci * BLOCK_K
            tma.async_load(
                p.weight_desc,
                [prog.pid_n * config.BLOCK_N, k_offset],
                ready_bar,
                p.b_bufs.index(state.index),
            )
            state = state.next()
@gluon.jit
def mma_partition(p):
    """MMA partition: accumulate over all tiles assigned to this CTA.

    Consumer of the load ring and producer of the accumulator ring: per tile
    it claims a TMEM accumulator slot, chains one tcgen05 MMA per K slice,
    and signals the epilogue when the accumulator is complete.
    """
    config = p.config
    num_k_iter = config.get_num_k_iterations()
    load_state = Counter.create(0, p.load_empty_bars.shape[0])
    acc_state = Counter.create(1, p.acc_empty_bars.shape[0])
    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())
    for _ in range(scheduler.get_num_tiles()):
        # Claim a free TMEM accumulator slot for this output tile.
        mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase)
        acc_buf = p.acc_bufs.index(acc_state.index)
        use_acc = False  # the first MMA of a tile overwrites the accumulator
        for _k_iter in range(num_k_iter):
            # Wait for this K slice's loads; B is stored as (BLOCK_N, BLOCK_K)
            # so it is permuted to (K, N) for the MMA.
            mbarrier.wait(p.load_ready_bars.index(load_state.index), load_state.phase)
            tcgen05_mma(
                p.a_bufs.index(load_state.index),
                p.b_bufs.index(load_state.index).permute((1, 0)),
                acc_buf,
                use_acc=use_acc,
            )
            # Release the load slot back to the producer once the MMA has
            # consumed it.
            tcgen05_commit(p.load_empty_bars.index(load_state.index))
            load_state = load_state.next()
            use_acc = True
        # Signal the epilogue that this tile's accumulator is ready.
        tcgen05_commit(p.acc_ready_bars.index(acc_state.index))
        acc_state = acc_state.next()
@gluon.jit
def epilogue_partition(p):
    """Epilogue partition: store all tiles assigned to this CTA.

    Consumer of the accumulator ring: per tile it loads the f32 accumulator
    from TMEM, downcasts, and scatters it back to the NHWC output tensor.
    """
    config = p.config
    BLOCK_M: gl.constexpr = config.BLOCK_M
    BLOCK_N: gl.constexpr = config.BLOCK_N
    M_GEMM = config.M_GEMM
    N_GEMM = config.Co
    acc_state = Counter.create(0, p.acc_empty_bars.shape[0])
    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())
    for idx in range(scheduler.get_num_tiles()):
        prog = config.get_program(scheduler.get_tile_id(idx))
        # Wait for the MMA partition, pull the accumulator out of TMEM, and
        # downcast to the GEMM dtype in a store-friendly layout.
        mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
        acc = p.acc_bufs.index(acc_state.index).load()
        result = gl.convert_layout(acc.to(GL_GEMM_DTYPE), gl.CoalescedLayout())
        # Free the TMEM slot as soon as registers hold the result.
        mbarrier.arrive(p.acc_empty_bars.index(acc_state.index), count=1)
        acc_state = acc_state.next()
        # Invert the M = (batch, out_y, out_x) flattening and combine with the
        # output strides to scatter the tile into NHWC memory, masking the
        # ragged edges of the GEMM.
        offs_m = prog.pid_m * BLOCK_M + gl.arange(0, BLOCK_M)
        offs_n = prog.pid_n * BLOCK_N + gl.arange(0, BLOCK_N)
        c_out_x = offs_m % config.out_w
        c_out_y = (offs_m // config.out_w) % config.out_h
        c_batch = (offs_m // config.out_w) // config.out_h
        c_offsets = (c_batch[:, None] * config.output_stride_n + c_out_y[:, None] * config.output_stride_h +
                     c_out_x[:, None] * config.output_stride_w + offs_n[None, :])
        c_mask = (offs_m[:, None] < M_GEMM) & (offs_n[None, :] < N_GEMM)
        gl.store(p.output_ptr + c_offsets, result, mask=c_mask)
    # Tear down all barrier rings once every tile has been stored.
    invalidate_mbarrier_ring(p.load_empty_bars)
    invalidate_mbarrier_ring(p.load_ready_bars)
    invalidate_mbarrier_ring(p.acc_empty_bars)
    invalidate_mbarrier_ring(p.acc_ready_bars)
# ===-----------------------------------------------------------------------===#
# Kernel Entry Point
# ===-----------------------------------------------------------------------===#
@gluon.jit(do_not_specialize=[
    "N",
    "H",
    "W",
    "R",
    "S",
    "pad_h",
    "pad_w",
])
def conv2d_fprop_kernel(
    in_desc,
    weight_desc,
    output,
    N,
    H,
    W,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    output_stride_n,
    output_stride_h,
    output_stride_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    BLOCK_M: gl.constexpr,
    BLOCK_N: gl.constexpr,
    BLOCK_K: gl.constexpr,
    GROUP_SIZE_M: gl.constexpr,
    num_buffers: gl.constexpr,
    num_acc_buffers: gl.constexpr,
    num_warps: gl.constexpr,
):
    """Warp-specialized forward convolution kernel.

    Builds the shared config, allocates the shared-memory / tensor-memory
    ring buffers with their mbarrier pairs, then runs three warp-specialized
    partitions concurrently: epilogue store, tcgen05 MMA, and TMA loads.
    """
    M_GEMM = N * out_h * out_w
    config = ConvConfig(
        N,
        H,
        W,
        Ci,
        Co,
        R,
        S,
        gl.to_tensor(out_h),
        gl.to_tensor(out_w),
        gl.to_tensor(stride_h),
        gl.to_tensor(stride_w),
        pad_h,
        pad_w,
        gl.to_tensor(output_stride_n),
        gl.to_tensor(output_stride_h),
        gl.to_tensor(output_stride_w),
        M_GEMM,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        GROUP_SIZE_M,
        num_buffers,
        num_warps,
    )
    # A/B operand staging buffers in shared memory, one slot per pipeline
    # stage, plus the empty/ready barrier ring that sequences producer and
    # consumer.
    a_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK_M, BLOCK_K], GL_GEMM_DTYPE)
    b_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK_N, BLOCK_K], GL_GEMM_DTYPE)
    a_bufs = gl.allocate_shared_memory(GL_GEMM_DTYPE, [num_buffers, BLOCK_M, BLOCK_K], a_smem_layout)
    b_bufs = gl.allocate_shared_memory(GL_GEMM_DTYPE, [num_buffers, BLOCK_N, BLOCK_K], b_smem_layout)
    load_empty_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    load_ready_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    init_mbarrier_ring(load_empty_bars)
    init_mbarrier_ring(load_ready_bars)
    # Tensor-memory accumulators (f32) with their own barrier ring.
    TMEM_BLOCK_M: gl.constexpr = 64 if BLOCK_M == 64 else 128
    tmem_layout: gl.constexpr = TensorMemoryLayout(block=(TMEM_BLOCK_M, BLOCK_N), col_stride=1)
    # Smaller tiles can profit from a double-buffered accumulator ring, but
    # large 256x256 tiles exceed Blackwell's TMEM budget unless the ring depth
    # is reduced to 1.
    acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout)
    acc_empty_bars = gl.allocate_shared_memory(gl.int64, [num_acc_buffers, 1], mbarrier.MBarrierLayout())
    acc_ready_bars = gl.allocate_shared_memory(gl.int64, [num_acc_buffers, 1], mbarrier.MBarrierLayout())
    init_mbarrier_ring(acc_empty_bars)
    init_mbarrier_ring(acc_ready_bars)
    p = PartitionArgs(
        config,
        in_desc,
        weight_desc,
        output,
        a_bufs,
        b_bufs,
        load_empty_bars,
        load_ready_bars,
        acc_bufs,
        acc_empty_bars,
        acc_ready_bars,
    )
    # NOTE(review): the trailing [1, 1] / [24, 24] lists look like per-extra-
    # partition warp counts and register budgets for warp_specialize -- confirm
    # against gl.warp_specialize's signature.
    gl.warp_specialize([
        (epilogue_partition, (p, )),
        (mma_partition, (p, )),
        (load_partition, (p, )),
    ], [1, 1], [24, 24])
def conv2d_fprop_get_configs(pre_hook=None):
    """Enumerate the autotuning search space for the fprop kernel.

    Sweeps tile shape and pipeline depth; BLOCK_K, GROUP_SIZE_M, the
    accumulator ring depth, and the warp count are held fixed.
    """
    configs = []
    for block_m in (64, 128):
        for block_n in (8, 32, 128, 256):
            for num_buffers in (3, 4, 5):
                meta = {
                    "BLOCK_M": block_m,
                    "BLOCK_N": block_n,
                    "BLOCK_K": 64,
                    "GROUP_SIZE_M": 4,
                    "num_buffers": num_buffers,
                    "num_acc_buffers": 2,
                }
                configs.append(triton.Config(meta, num_warps=4, pre_hook=pre_hook))
    return configs
def conv2d_fprop_tma_set_block_size_hook(nargs):
    """Autotune pre-hook: resize both TMA descriptors to the chosen tile shape."""
    tile_shapes = (
        ("in_desc", [nargs["BLOCK_M"], nargs["BLOCK_K"]]),
        ("weight_desc", [nargs["BLOCK_N"], nargs["BLOCK_K"]]),
    )
    for desc_name, shape in tile_shapes:
        descriptor = nargs[desc_name]
        descriptor.block_shape = shape
        descriptor.layout = gl.NVMMASharedLayout.get_default_for(shape, GL_GEMM_DTYPE)
# Key on the effective implicit-GEMM/convolution geometry instead of the full
# raw input shape. `out_h/out_w` already encode the impact of H/W/padding on
# the launch shape, so keeping all of them would only fragment the autotune
# cache without exposing meaningfully different tile choices.
# The pre_hook resizes the TMA descriptors to each candidate config's tile
# shape before launch (see conv2d_fprop_tma_set_block_size_hook above).
conv2d_fprop_autotuned_kernel = triton.autotune(
    configs=conv2d_fprop_get_configs(pre_hook=conv2d_fprop_tma_set_block_size_hook),
    key=["out_h", "out_w", "stride_h", "stride_w"],
)(conv2d_fprop_kernel)
# ===-----------------------------------------------------------------------===#
# Host-Side Entry Point
# ===-----------------------------------------------------------------------===#
def _prepare_conv_fprop_inputs(input_tensor, weight_tensor, stride, padding):
    """Validate NHWC/OHWI conv inputs and allocate the NHWC output tensor.

    Pads narrow channel dimensions for TMA alignment, computes the output
    spatial size, and returns (input, weight, output) followed by the full
    geometry: N, H, W, Ci, Co, R, S, out_h, out_w, stride_h, stride_w,
    pad_h, pad_w.
    """
    _, _, _, Ci = input_tensor.shape
    _, _, _, Ci_w = weight_tensor.shape
    assert Ci == Ci_w, "Input and weight channel dimensions must match"
    wrong_dtype = input_tensor.dtype != TORCH_GEMM_DTYPE or weight_tensor.dtype != TORCH_GEMM_DTYPE
    if wrong_dtype:
        raise ValueError(
            f"conv2d_fprop expects bfloat16 input/weight tensors, got {input_tensor.dtype} and {weight_tensor.dtype}")
    stride_h, stride_w = normalize_2d(stride, "stride")
    pad_h, pad_w = normalize_2d(padding, "padding")
    if min(stride_h, stride_w) <= 0:
        raise ValueError(f"stride must be positive, got {(stride_h, stride_w)}")
    if min(pad_h, pad_w) < 0:
        raise ValueError(f"padding must be non-negative, got {(pad_h, pad_w)}")
    # The Hopper/Blackwell TMA path requires outer strides to be 16-byte aligned.
    # For NHWC/OHWI bf16 tensors this means padding the channel dimension for
    # narrow inputs such as RGB (Ci=3).
    input_tensor, weight_tensor = maybe_pad_ci_for_tma(input_tensor, weight_tensor)
    # Re-read shapes after potential channel padding.
    N, H, W, Ci = input_tensor.shape
    Co, R, S, Ci_w = weight_tensor.shape
    out_h = (H + 2 * pad_h - R) // stride_h + 1
    out_w = (W + 2 * pad_w - S) // stride_w + 1
    if out_h <= 0 or out_w <= 0:
        raise ValueError("Invalid convolution geometry: computed output size "
                         f"({out_h}, {out_w}) from H={H}, W={W}, R={R}, S={S}, "
                         f"stride={(stride_h, stride_w)}, padding={(pad_h, pad_w)}.")
    output = torch.empty((N, out_h, out_w, Co), device=input_tensor.device, dtype=TORCH_GEMM_DTYPE)
    return input_tensor, weight_tensor, output, N, H, W, Ci, Co, R, S, out_h, out_w, stride_h, stride_w, pad_h, pad_w
def _make_conv_fprop_descriptors(input_tensor, weight_tensor, out_h, out_w, stride_h, stride_w, pad_h, pad_w,
                                 input_block_shape, weight_block_shape):
    """Build the TMA descriptors for the activation and weight loads.

    Returns (in_desc, weight_desc): an im2col descriptor over the NHWC input
    and a tiled descriptor over the weight reshaped to (Co, R*S*Ci).
    """
    # TMA im2col descriptor for input: [N, H, W, Ci] in NHWC
    #
    # The pixel_box defines the access boundary per batch:
    #   Lower = pixel_box_lower_corner + offsets
    #   Upper = [H, W] + pixel_box_upper_corner + offsets
    # With element_strides = [1, stride_h, stride_w, 1], TMA steps by the
    # per-dimension convolution stride between output pixels. The window must
    # contain exactly out_h * out_w pixels per batch:
    #   pixels_h = floor((window_h - 1) / stride_h) + 1 = out_h
    #   pixels_w = floor((window_w - 1) / stride_w) + 1 = out_w
    #   => window_h = (out_h - 1) * stride_h + 1
    #   => window_w = (out_w - 1) * stride_w + 1
    _, H, W, _ = input_tensor.shape
    upper_h = (out_h - 1) * stride_h + 1 - H - pad_h
    upper_w = (out_w - 1) * stride_w + 1 - W - pad_w
    input_layout = gl.NVMMASharedLayout.get_default_for(input_block_shape, GL_GEMM_DTYPE)
    in_desc = TensorDescriptorIm2Col(
        base=input_tensor,
        shape=list(input_tensor.shape),
        strides=list(input_tensor.stride()),
        block_shape=input_block_shape,
        layout=input_layout,
        # Out-of-bounds (halo/padding) pixels read as zero.
        padding="zero",
        element_strides=[1, stride_h, stride_w, 1],
        pixel_box_lower_corner=[-pad_h, -pad_w],
        pixel_box_upper_corner=[upper_h, upper_w],
    )
    # TMA tiled descriptor for weight: (Co, R*S*Ci) = (N_GEMM, K_GEMM)
    Co, R, S, Ci = weight_tensor.shape
    weight_reshaped = weight_tensor.reshape(Co, R * S * Ci)
    weight_layout = gl.NVMMASharedLayout.get_default_for(weight_block_shape, GL_GEMM_DTYPE)
    weight_desc = TensorDescriptor.from_tensor(weight_reshaped, weight_block_shape, weight_layout)
    return in_desc, weight_desc
def _make_grid(num_sms, M_GEMM, N_GEMM):
    """Build a persistent launch grid: one program per SM, capped by tile count."""

    def grid(meta):
        tiles_m = triton.cdiv(M_GEMM, meta["BLOCK_M"])
        tiles_n = triton.cdiv(N_GEMM, meta["BLOCK_N"])
        return (min(num_sms, tiles_m * tiles_n), )

    return grid
def _launch_conv(
    kernel,
    grid,
    *,
    in_desc,
    weight_desc,
    output,
    N,
    H,
    W,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    kernel_meta=None,
):
    """Launch `kernel` on `grid`, marshalling the conv geometry positionally.

    The positional argument order below must match conv2d_fprop_kernel's
    signature exactly; the output NHWC strides are derived here from `output`.
    `kernel_meta` carries constexpr meta-parameters for the non-autotuned path.
    """
    if kernel_meta is None:
        kernel_meta = {}
    kernel[grid](
        in_desc,
        weight_desc,
        output,
        N,
        H,
        W,
        Ci,
        Co,
        R,
        S,
        out_h,
        out_w,
        output.stride(0),
        output.stride(1),
        output.stride(2),
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        **kernel_meta,
    )
def conv2d_fprop(input_tensor, weight_tensor, stride=1, padding=0, **kwargs):
    """Production fprop entrypoint.

    Selects the best kernel configuration with Triton autotuning for the given
    convolution shape.
    """
    (input_tensor, weight_tensor, output, N, H, W, Ci, Co, R, S, out_h, out_w, stride_h, stride_w, pad_h,
     pad_w) = _prepare_conv_fprop_inputs(input_tensor, weight_tensor, stride, padding)
    sm_count = torch.cuda.get_device_properties(input_tensor.device).multi_processor_count
    # The block shapes here are placeholders: the autotuner's pre_hook rewrites
    # both descriptors to the chosen tile shape before every launch.
    placeholder_shape = [1, 1]
    in_desc, weight_desc = _make_conv_fprop_descriptors(
        input_tensor,
        weight_tensor,
        out_h,
        out_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        placeholder_shape,
        placeholder_shape,
    )
    _launch_conv(
        conv2d_fprop_autotuned_kernel,
        _make_grid(sm_count, N * out_h * out_w, Co),
        in_desc=in_desc,
        weight_desc=weight_desc,
        output=output,
        N=N,
        H=H,
        W=W,
        Ci=Ci,
        Co=Co,
        R=R,
        S=S,
        out_h=out_h,
        out_w=out_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
    )
    return output


# Backwards-compatible alias for the persistent (autotuned) entrypoint.
conv2d_fprop_persistent = conv2d_fprop
def _make_conv2d_fprop_fixed_kernel_meta(num_buffers, num_warps):
    """Kernel meta for the fixed (non-autotuned) path.

    Keeps the fixed path on a 128x128x64 tile shape that is also covered by
    the autotune configs.
    """
    return dict(
        BLOCK_M=128,
        BLOCK_N=128,
        BLOCK_K=64,
        GROUP_SIZE_M=4,
        num_buffers=num_buffers,
        num_acc_buffers=2,
        num_warps=num_warps,
    )
def conv2d_fprop_fixed(input_tensor, weight_tensor, stride=1, padding=0, num_buffers=3, num_warps=4):
    """Fixed-config fprop entrypoint used for CI and debugging.

    Runs the kernel with a fixed supported tile shape instead of autotuning.
    """
    (input_tensor, weight_tensor, output, N, H, W, Ci, Co, R, S, out_h, out_w, stride_h, stride_w, pad_h,
     pad_w) = _prepare_conv_fprop_inputs(input_tensor, weight_tensor, stride, padding)
    kernel_meta = _make_conv2d_fprop_fixed_kernel_meta(num_buffers, num_warps)
    block_m = kernel_meta["BLOCK_M"]
    block_n = kernel_meta["BLOCK_N"]
    block_k = kernel_meta["BLOCK_K"]
    sm_count = torch.cuda.get_device_properties(input_tensor.device).multi_processor_count
    # Descriptors are sized for the fixed tile shape up front since no
    # autotuner pre_hook runs on this path.
    in_desc, weight_desc = _make_conv_fprop_descriptors(
        input_tensor,
        weight_tensor,
        out_h,
        out_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        [block_m, block_k],
        [block_n, block_k],
    )
    _launch_conv(
        conv2d_fprop_kernel,
        _make_grid(sm_count, N * out_h * out_w, Co),
        in_desc=in_desc,
        weight_desc=weight_desc,
        output=output,
        N=N,
        H=H,
        W=W,
        Ci=Ci,
        Co=Co,
        R=R,
        S=S,
        out_h=out_h,
        out_w=out_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        kernel_meta=kernel_meta,
    )
    return output
# ===-----------------------------------------------------------------------===#
# Unit Tests
# ===-----------------------------------------------------------------------===#
def _assert_conv_fprop_correct(fprop_fn, N, Ci, H, W, Co, R, S, stride, padding, **kwargs):
    """Run fprop_fn on NHWC tensors and compare against torch.nn.functional.conv2d."""
    torch.manual_seed(0)
    # Generate in NCHW (reference layout) first, then derive contiguous NHWC
    # copies for the Gluon kernel; randn order is kept (x before w) so the RNG
    # stream matches across runs.
    x_nchw = torch.randn((N, Ci, H, W), device="cuda", dtype=TORCH_GEMM_DTYPE)
    x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()
    w_nchw = torch.randn((Co, Ci, R, S), device="cuda", dtype=TORCH_GEMM_DTYPE)
    w_nhwc = w_nchw.permute(0, 2, 3, 1).contiguous()
    actual = fprop_fn(x_nhwc, w_nhwc, stride=stride, padding=padding, **kwargs)
    expected = torch.nn.functional.conv2d(x_nchw, w_nchw, stride=stride, padding=padding).permute(0, 2, 3, 1)
    torch.testing.assert_close(actual, expected, atol=5e-2, rtol=5e-2)
@pytest.mark.parametrize("fprop_fn,N,Ci,H,W,Co,R,S,stride,padding", [
    # Dense sweep over batch, channels, filter size, stride, and padding at a
    # fixed 64x64 spatial size.
    *[(conv2d_fprop_fixed, N, Ci, 64, 64, Co, R, S, stride, padding)
      for N in (1, 128)
      for Ci, Co in ((384, 384), (416, 416))
      for R, S in ((3, 3), (4, 4), (5, 5))
      for stride in (1, 2)
      for padding in (0, 1)],
    (conv2d_fprop_fixed, 1, 96, 1, 8, 128, 1, 2, (1, 2), 0),  # asymmetric stride
    (conv2d_fprop_fixed, 16, 5, 32, 32, 96, 3, 3, 1, 1),  # padded channels
])
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell GPU (SM 10.x)")
def test_op(fprop_fn, N, Ci, H, W, Co, R, S, stride, padding):
    """End-to-end correctness of the fixed-config fprop against PyTorch."""
    _assert_conv_fprop_correct(fprop_fn, N, Ci, H, W, Co, R, S, stride, padding)
# ===-----------------------------------------------------------------------===#
# Benchmarking
# ===-----------------------------------------------------------------------===#
# Benchmark sweep: a single representative conv shape.
BATCH = [128]
CHANNELS = [(384, 384)]  # (Ci, Co)
SPATIAL = [(64, 64)]  # (H, W)
FILTER = [(3, 3)]  # (R, S)
STRIDE = [1]
PADDING = [1]
def _make_bench_inputs(N, H, W, Ci, Co, R, S):
    """Create matching NCHW/NHWC input and weight tensors for benchmarking."""
    torch.manual_seed(0)
    # randn order (x before w) matches the correctness harness so both see the
    # same RNG stream.
    x_nchw = torch.randn((N, Ci, H, W), device="cuda", dtype=TORCH_GEMM_DTYPE)
    w_nchw = torch.randn((Co, Ci, R, S), device="cuda", dtype=TORCH_GEMM_DTYPE)
    x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()
    w_nhwc = w_nchw.permute(0, 2, 3, 1).contiguous()
    return x_nchw, x_nhwc, w_nchw, w_nhwc
def _benchmark_tflops(fn, *, N, H, W, Ci, Co, R, S, stride_val, pad_val):
    """Time `fn` with do_bench and convert the mean latency to TFLOPS."""
    ms = triton.testing.do_bench(fn)
    out_h = (H + 2 * pad_val - R) // stride_val + 1
    out_w = (W + 2 * pad_val - S) // stride_val + 1
    # 2 flops (multiply + add) per MAC of the implicit GEMM.
    total_flops = 2.0 * N * out_h * out_w * Co * Ci * R * S
    seconds = ms * 1e-3
    return total_flops * 1e-12 / seconds
# Build one Benchmark entry per point in the (batch, channels, spatial,
# filter, stride, padding) sweep defined above.
bench_configs = []
for N in BATCH:
    for Ci, Co in CHANNELS:
        for H, W in SPATIAL:
            for R, S in FILTER:
                for stride_val in STRIDE:
                    for pad_val in PADDING:
                        bench_configs.append(
                            triton.testing.Benchmark(
                                x_names=["kernel"],
                                x_vals=["autotuned"],
                                line_arg="provider",
                                line_vals=["gluon", "torch"],
                                line_names=["Gluon (autotuned)", "PyTorch"],
                                styles=[("green", "-"), ("blue", "-")],
                                ylabel="TFLOPS",
                                plot_name=
                                f"Conv2d N={N} Ci={Ci} Co={Co} H={H} W={W} R={R} S={S} stride={stride_val} pad={pad_val}",
                                args={
                                    "N": N,
                                    "H": H,
                                    "W": W,
                                    "Ci": Ci,
                                    "Co": Co,
                                    "R": R,
                                    "S": S,
                                    "stride_val": stride_val,
                                    "pad_val": pad_val,
                                },
                            ))
@triton.testing.perf_report(bench_configs)
def bench(N, H, W, Ci, Co, R, S, stride_val, pad_val, kernel, provider):
    """Benchmark one provider on one conv shape, returning TFLOPS."""
    x_nchw, x_nhwc, w_nchw, w_nhwc = _make_bench_inputs(N, H, W, Ci, Co, R, S)
    if provider == "gluon":
        run = lambda: conv2d_fprop(x_nhwc, w_nhwc, stride=stride_val, padding=pad_val)
    elif provider == "torch":
        run = lambda: torch.nn.functional.conv2d(x_nchw, w_nchw, stride=stride_val, padding=pad_val)
    else:
        raise ValueError(f"Unsupported provider: {provider}")
    return _benchmark_tflops(
        run,
        N=N,
        H=H,
        W=W,
        Ci=Ci,
        Co=Co,
        R=R,
        S=S,
        stride_val=stride_val,
        pad_val=pad_val,
    )


if __name__ == "__main__":
    bench.run(save_path=".", print_data=True)