(example_conv-wgrad)=

# Conv Wgrad

This example can be found at ``python/examples/gluon/02-conv-wgrad.py``.

```python
import importlib.util
import sys
from pathlib import Path

import pytest
import torch
import triton
import triton.language as tl
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor, TensorDescriptorIm2Col
from triton.experimental.gluon.language.nvidia.hopper import tma, mbarrier
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout,
    allocate_tensor_memory,
    tensor_memory_descriptor,
    tcgen05_mma,
    tcgen05_commit,
)


def _load_conv_common():
    """Load the sibling ``02-conv-common.py`` file as a module.

    The file name is not a valid Python identifier, so it is loaded by path
    and cached in ``sys.modules`` under a fixed name; later calls return the
    cached module.

    Raises:
        ImportError: if no import spec/loader can be built for the file.
    """
    module_name = "triton_examples_gluon_conv_common"
    module = sys.modules.get(module_name)
    if module is not None:
        return module

    module_path = Path(__file__).with_name("02-conv-common.py")
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Unable to load shared conv helpers from {module_path}")

    module = importlib.util.module_from_spec(spec)
    # Register before executing so the module can be found by name while it
    # runs, but never leave a half-initialised module behind on failure.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(module_name, None)
        raise
    return module


_conv_common = _load_conv_common()

# ===-----------------------------------------------------------------------===#
# Utilities
# ===-----------------------------------------------------------------------===#

Counter = _conv_common.Counter
GL_GEMM_DTYPE = _conv_common.GL_GEMM_DTYPE
PersistentTileScheduler = _conv_common.PersistentTileScheduler
TORCH_GEMM_DTYPE = _conv_common.TORCH_GEMM_DTYPE
init_mbarrier_ring = _conv_common.init_mbarrier_ring
invalidate_mbarrier_ring = _conv_common.invalidate_mbarrier_ring
is_blackwell = _conv_common.is_blackwell
maybe_pad_ci_for_tma = _conv_common.maybe_pad_channel_dims_for_tma
normalize_2d = _conv_common.normalize_2d

# ===-----------------------------------------------------------------------===#
# Wgrad GEMM mapping
# ===-----------------------------------------------------------------------===#
#
#   grad_W[Co, R*S*Ci] = grad_out[M, Co]^T @ im2col(input)[M, R*S*Ci]
#
#   where M = N * out_h * out_w  (spatial positions — reduction dimension)
#
# MMA tiling:
#   BLOCK_M = tile over Co            (rows of grad_weight)
#   BLOCK_N = tile over Ci per (r,s)  (cols of grad_weight)
#   BLOCK_K = tile over spatial       (reduction)
#
# Logical tile space: cdiv(Co, BLOCK_M) * R * S * cdiv(Ci, BLOCK_N), optionally
# multiplied by split-K. The launch uses a persistent scheduler and runs only
# `min(num_sms, logical_tiles)` CTAs.
#
# Loads per K iteration:
#   A = grad_out tile:       TMA tiled on (M_spatial, Co),
#                            block [BLOCK_K, BLOCK_M] — permuted to [M, K] in kernel.
#   B = im2col(input) tile:  TMA im2col on [N,H,W,Ci], block [BLOCK_K, BLOCK_N]
#                            Already [K, N], no kernel permute.
#
# MMA: acc[BLOCK_M, BLOCK_N] += A.permute(1,0) @ B

# ===-----------------------------------------------------------------------===#
# Wgrad Configuration
# ===-----------------------------------------------------------------------===#


# Problem sizes (runtime tensors) and tile parameters (constexprs) for one
# launch, plus helpers that map a linear work-item id to tile coordinates.
@gluon.aggregate
class WgradConfig:
    N: gl.tensor
    Ci: gl.tensor
    Co: gl.tensor
    R: gl.tensor
    S: gl.tensor
    out_h: gl.tensor
    out_w: gl.tensor
    stride_h: gl.tensor
    stride_w: gl.tensor
    pad_h: gl.tensor
    pad_w: gl.tensor
    K_GEMM: gl.tensor
    M_spatial: gl.tensor

    BLOCK_M: gl.constexpr
    BLOCK_N: gl.constexpr
    BLOCK_K: gl.constexpr
    SPLIT_K: gl.constexpr
    num_buffers: gl.constexpr
    num_warps: gl.constexpr

    @gluon.jit
    def get_num_output_tiles(self):
        """Number of output tiles: Co blocks * R * S * Ci blocks."""
        co_num_blocks = gl.cdiv(self.Co, self.BLOCK_M)
        ci_num_blocks = gl.cdiv(self.Ci, self.BLOCK_N)
        return co_num_blocks * self.R * self.S * ci_num_blocks

    @gluon.jit
    def get_num_k_iterations(self):
        """Number of BLOCK_K steps needed to cover the spatial reduction."""
        return gl.cdiv(self.M_spatial, self.BLOCK_K)

    @gluon.jit
    def get_active_split_k(self):
        """Number of split-K slices that actually receive work.

        Each slice gets ceil(total / SPLIT_K) iterations, so fewer than
        SPLIT_K slices may be needed to cover all iterations.
        """
        total_k_iters = self.get_num_k_iterations()
        k_iters_per_split = gl.cdiv(total_k_iters, self.SPLIT_K)
        return gl.cdiv(total_k_iters, k_iters_per_split)

    @gluon.jit
    def get_num_tiles(self):
        """Total work items: output tiles times active split-K slices."""
        return self.get_num_output_tiles() * self.get_active_split_k()

    @gluon.jit
    def get_program(self, pid):
        """Decode a linear work-item id into a WgradProgram.

        The split-K index varies fastest, then the Co block, then the Ci
        block, then the flattened (r, s) filter position.
        """
        active_split_k = self.get_active_split_k()
        split_k_idx = pid % active_split_k
        tile_id = pid // active_split_k

        ci_num_blocks = gl.cdiv(self.Ci, self.BLOCK_N)
        co_num_blocks = gl.cdiv(self.Co, self.BLOCK_M)

        pid_co = tile_id % co_num_blocks
        pid_n = tile_id // co_num_blocks

        ci_block = pid_n % ci_num_blocks
        rs_idx = pid_n // ci_num_blocks
        iter_r = rs_idx // self.S
        iter_s = rs_idx % self.S

        total_k_iters = self.get_num_k_iterations()
        k_iters_per_split = gl.cdiv(total_k_iters, active_split_k)
        k_start = split_k_idx * k_iters_per_split
        remaining_k_iters = total_k_iters - k_start
        zero = gl.to_tensor(0)
        # Clamp the last slice to what is left; never go negative.
        k_iters_this_split = gl.where(
            remaining_k_iters > 0,
            gl.minimum(k_iters_per_split, remaining_k_iters),
            zero,
        )

        return WgradProgram(self, pid_co, ci_block, iter_r, iter_s, split_k_idx, k_start, k_iters_this_split)


# Coordinates of one work item: a (Co block, Ci block, r, s) output tile and
# the range of K iterations handled by one split-K slice.
@gluon.aggregate
class WgradProgram:
    config: WgradConfig
    pid_co: gl.tensor
    ci_block: gl.tensor
    iter_r: gl.tensor
    iter_s: gl.tensor
    split_k_idx: gl.tensor
    k_start: gl.tensor
    k_iters_this_split: gl.tensor

    @gluon.jit
    def get_co_offset(self):
        """First output-channel index of this tile."""
        return self.pid_co * self.config.BLOCK_M

    @gluon.jit
    def get_ci_offset(self):
        """First input-channel index of this tile."""
        return self.ci_block * self.config.BLOCK_N

    @gluon.jit
    def get_spatial_offsets(self, local_k):
        """Return (m_global, batch, out_y, out_x) for K iteration `local_k`.

        m_global is the first flattened N*out_h*out_w position of the
        iteration; the other three values are its decomposition.
        """
        m_global = (self.k_start + local_k) * self.config.BLOCK_K
        spatial_per_batch = self.config.out_h * self.config.out_w
        m_in_batch = m_global % spatial_per_batch
        batch = m_global // spatial_per_batch
        out_x = m_in_batch % self.config.out_w
        out_y = m_in_batch // self.config.out_w
        return m_global, batch, out_y, out_x

    @gluon.jit
    def get_weight_k_offset(self):
        """Column offset of this tile in the flattened [Co, R*S*Ci] output."""
        return (self.iter_r * self.config.S + self.iter_s) * self.config.Ci + self.get_ci_offset()


# ===-----------------------------------------------------------------------===#
# Partition Arguments
# ===-----------------------------------------------------------------------===#


# Everything the three warp-specialized partitions share: the config, the TMA
# descriptors, the output pointer, and the buffer/barrier rings.
@gluon.aggregate
class PartitionArgs:
    config: WgradConfig
    in_desc: tma.tensor_descriptor_im2col
    grad_out_desc: tma.tensor_descriptor
    grad_weight_ptr: gl.tensor
    grad_weight_stride_0: gl.tensor
    a_bufs: gl.shared_memory_descriptor
    b_bufs: gl.shared_memory_descriptor
    load_empty_bars: gl.shared_memory_descriptor
    load_ready_bars: gl.shared_memory_descriptor
    acc_bufs: tensor_memory_descriptor
    acc_empty_bars: gl.shared_memory_descriptor
    acc_ready_bars: gl.shared_memory_descriptor


# ===-----------------------------------------------------------------------===#
# Warp-Specialized Partitions
# ===-----------------------------------------------------------------------===#


@gluon.jit
def load_partition(p):
    """Load partition: iterate over the persistent wgrad work items assigned to this CTA."""
    config = p.config
    empty_bars = p.load_empty_bars
    ready_bars = p.load_ready_bars
    # NOTE(review): Counter is defined in 02-conv-common.py; the first argument
    # is presumably the starting phase, chosen so the first waits on the
    # "empty" barriers pass immediately — confirm against that file.
    state = Counter.create(1, empty_bars.shape[0])
    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())

    for idx in range(scheduler.get_num_tiles()):
        prog = config.get_program(scheduler.get_tile_id(idx))
        co_offset = prog.get_co_offset()
        ci_offset = prog.get_ci_offset()

        for local_k in range(prog.k_iters_this_split):
            m_global, batch, out_y, out_x = prog.get_spatial_offsets(local_k)

            ready_bar = ready_bars.index(state.index)
            mbarrier.wait(empty_bars.index(state.index), state.phase)
            # One barrier tracks both loads of this stage.
            mbarrier.expect(ready_bar, p.grad_out_desc.block_type.nbytes + p.in_desc.block_type.nbytes)

            # A = grad_output: (M_spatial, Co), block [BLOCK_K, BLOCK_M]
            tma.async_load(
                p.grad_out_desc,
                [m_global, co_offset],
                ready_bar,
                p.a_bufs.index(state.index),
            )

            # B = im2col(input): [N, H, W, Ci], block [BLOCK_K, BLOCK_N]
            tma.async_load_im2col(
                p.in_desc,
                [
                    batch,
                    out_y * config.stride_h - config.pad_h,
                    out_x * config.stride_w - config.pad_w,
                    ci_offset,
                ],
                [prog.iter_r.to(tl.int16), prog.iter_s.to(tl.int16)],
                ready_bar,
                p.b_bufs.index(state.index),
            )

            state = state.next()


@gluon.jit
def mma_partition(p):
    """MMA partition: accumulate all split-K work items assigned to this CTA."""
    config = p.config
    load_state = Counter.create(0, p.load_empty_bars.shape[0])
    acc_state = Counter.create(1, p.acc_empty_bars.shape[0])
    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())

    for idx in range(scheduler.get_num_tiles()):
        prog = config.get_program(scheduler.get_tile_id(idx))

        mbarrier.wait(p.acc_empty_bars.index(acc_state.index), acc_state.phase)
        acc_buf = p.acc_bufs.index(acc_state.index)
        # The first MMA of a work item overwrites the accumulator.
        use_acc = False
        for _local_k in range(prog.k_iters_this_split):
            mbarrier.wait(p.load_ready_bars.index(load_state.index), load_state.phase)
            tcgen05_mma(
                p.a_bufs.index(load_state.index).permute((1, 0)),
                p.b_bufs.index(load_state.index),
                acc_buf,
                use_acc=use_acc,
            )
            # Hand the input stage back to the loader once the MMA is done.
            tcgen05_commit(p.load_empty_bars.index(load_state.index))
            load_state = load_state.next()
            use_acc = True

        tcgen05_commit(p.acc_ready_bars.index(acc_state.index))
        acc_state = acc_state.next()


@gluon.jit
def epilogue_partition(p):
    """Epilogue partition: store the persistent wgrad work items assigned to this CTA."""
    config = p.config
    active_split_k = config.get_active_split_k()
    BLOCK_M: gl.constexpr = config.BLOCK_M
    BLOCK_N: gl.constexpr = config.BLOCK_N
    acc_state = Counter.create(0, p.acc_empty_bars.shape[0])
    scheduler = PersistentTileScheduler.initialize(config.get_num_tiles())

    for idx in range(scheduler.get_num_tiles()):
        prog = config.get_program(scheduler.get_tile_id(idx))
        co_offset = prog.get_co_offset()
        ci_offset = prog.get_ci_offset()
        weight_k_offset = prog.get_weight_k_offset()

        mbarrier.wait(p.acc_ready_bars.index(acc_state.index), acc_state.phase)
        acc = p.acc_bufs.index(acc_state.index).load()
        result = gl.convert_layout(acc, gl.CoalescedLayout())
        # The accumulator has been read out; release it to the MMA partition.
        mbarrier.arrive(p.acc_empty_bars.index(acc_state.index), count=1)
        acc_state = acc_state.next()

        # With more than one active split, slice i writes its partial result
        # at row offset i * Co (the host allocates active_split_k * Co rows).
        split_co_offset = gl.where(active_split_k > 1, prog.split_k_idx * config.Co, gl.to_tensor(0))

        offs_m = co_offset + gl.arange(0, BLOCK_M)
        offs_n = weight_k_offset + gl.arange(0, BLOCK_N)
        # Keep a partial Ci tile from spilling into the next (r, s) column group.
        ci_valid = (ci_offset + gl.arange(0, BLOCK_N)) < config.Ci
        mask = (offs_m[:, None] < config.Co) & (offs_n[None, :] < config.K_GEMM) & ci_valid[None, :]
        store_rows = split_co_offset + offs_m
        offsets = store_rows[:, None] * p.grad_weight_stride_0 + offs_n[None, :]
        gl.store(p.grad_weight_ptr + offsets, result, mask=mask)

    invalidate_mbarrier_ring(p.load_empty_bars)
    invalidate_mbarrier_ring(p.load_ready_bars)
    invalidate_mbarrier_ring(p.acc_empty_bars)
    invalidate_mbarrier_ring(p.acc_ready_bars)


# ===-----------------------------------------------------------------------===#
# Kernel Entry Point
# ===-----------------------------------------------------------------------===#


@gluon.jit(do_not_specialize=[
    "N",
    "R",
    "S",
    "out_h",
    "out_w",
    "stride_h",
    "stride_w",
    "pad_h",
    "pad_w",
])
def conv2d_wgrad_kernel(
    in_desc,
    grad_out_desc,
    grad_weight,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    K_GEMM,
    grad_weight_stride_0,
    BLOCK_M: gl.constexpr,
    BLOCK_N: gl.constexpr,
    BLOCK_K: gl.constexpr,
    SPLIT_K: gl.constexpr,
    num_buffers: gl.constexpr,
    num_acc_buffers: gl.constexpr,
    num_warps: gl.constexpr,
):
    """Warp-specialized wgrad kernel: grad_W = grad_out^T @ im2col(input).

    GEMM dimensions (per CTA):
        M = Co tile                      (output rows)
        N = Ci tile at fixed (r,s)       (output cols)
        K = N_batch * out_h * out_w      (spatial reduction, split across SPLIT_K CTAs)
    """
    M_spatial = N * out_h * out_w

    config = WgradConfig(
        N,
        Ci,
        Co,
        R,
        S,
        gl.to_tensor(out_h),
        gl.to_tensor(out_w),
        gl.to_tensor(stride_h),
        gl.to_tensor(stride_w),
        pad_h,
        pad_w,
        K_GEMM,
        M_spatial,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        SPLIT_K,
        num_buffers,
        num_warps,
    )

    # a_bufs: grad_output tiles [BLOCK_K, BLOCK_M] (spatial × Co)
    # TMA loads from (M_spatial, Co), permuted to [BLOCK_M, BLOCK_K] at MMA call.
    a_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_M], GL_GEMM_DTYPE)
    # b_bufs: im2col input tiles [BLOCK_K, BLOCK_N] (spatial × Ci)
    b_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for([BLOCK_K, BLOCK_N], GL_GEMM_DTYPE)

    a_bufs = gl.allocate_shared_memory(GL_GEMM_DTYPE, [num_buffers, BLOCK_K, BLOCK_M], a_smem_layout)
    b_bufs = gl.allocate_shared_memory(GL_GEMM_DTYPE, [num_buffers, BLOCK_K, BLOCK_N], b_smem_layout)

    load_empty_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    load_ready_bars = gl.allocate_shared_memory(gl.int64, [num_buffers, 1], mbarrier.MBarrierLayout())
    init_mbarrier_ring(load_empty_bars)
    init_mbarrier_ring(load_ready_bars)

    TMEM_BLOCK_M: gl.constexpr = 64 if BLOCK_M == 64 else 128
    tmem_layout: gl.constexpr = TensorMemoryLayout(block=(TMEM_BLOCK_M, BLOCK_N), col_stride=1)
    acc_bufs = allocate_tensor_memory(gl.float32, [num_acc_buffers, BLOCK_M, BLOCK_N], tmem_layout)

    acc_empty_bars = gl.allocate_shared_memory(gl.int64, [num_acc_buffers, 1], mbarrier.MBarrierLayout())
    acc_ready_bars = gl.allocate_shared_memory(gl.int64, [num_acc_buffers, 1], mbarrier.MBarrierLayout())
    init_mbarrier_ring(acc_empty_bars)
    init_mbarrier_ring(acc_ready_bars)

    p = PartitionArgs(
        config,
        in_desc,
        grad_out_desc,
        grad_weight,
        gl.to_tensor(grad_weight_stride_0),
        a_bufs,
        b_bufs,
        load_empty_bars,
        load_ready_bars,
        acc_bufs,
        acc_empty_bars,
        acc_ready_bars,
    )

    gl.warp_specialize([
        (epilogue_partition, (p, )),
        (mma_partition, (p, )),
        (load_partition, (p, )),
    ], [1, 1], [24, 24])


# ===-----------------------------------------------------------------------===#
# Autotuning
# ===-----------------------------------------------------------------------===#


def conv2d_wgrad_get_configs(pre_hook=None):
    """Return the candidate `triton.Config` list searched by the host-side autotuner."""
    return [
        triton.Config(
            {
                "BLOCK_M": block_m,
                "BLOCK_N": block_n,
                "BLOCK_K": block_k,
                "SPLIT_K": split_k,
                "num_buffers": num_buffers,
                "num_acc_buffers": num_acc_buffers,
            },
            num_warps=num_warps,
            pre_hook=pre_hook,
        )
        for block_m in (64, 128)
        for block_n in (64, 128, 256)
        for block_k in (64, )
        for split_k in (1, 2, 4, 8, 16, 32)
        for num_buffers in (3, 4)
        for num_acc_buffers in (2, )
        for num_warps in (4, )
    ]


# ===-----------------------------------------------------------------------===#
# Host-Side Entry Point
# ===-----------------------------------------------------------------------===#


def _prepare_wgrad_problem(input_nhwc, grad_output_nhwc, R, S, stride, padding):
    """Validate inputs, pad channels, and return derived quantities.

    Returns:
        (input_nhwc, grad_output_nhwc, Ci_orig, N, Ci, Co, out_h, out_w,
        stride_h, stride_w, pad_h, pad_w, K_GEMM), where `input_nhwc` is the
        result of `maybe_pad_ci_for_tma`, `Ci` is its channel count,
        `Ci_orig` is the caller's channel count, and K_GEMM = R * S * Ci.

    Raises:
        ValueError: on wrong dtype, non-positive stride, negative padding,
            batch-size mismatch, or a grad-output shape that does not match
            the input/filter geometry.
    """
    if input_nhwc.dtype != TORCH_GEMM_DTYPE or grad_output_nhwc.dtype != TORCH_GEMM_DTYPE:
        raise ValueError(
            f"conv2d_wgrad expects bfloat16 input and grad-output tensors, got {input_nhwc.dtype} and {grad_output_nhwc.dtype}"
        )

    stride_h, stride_w = normalize_2d(stride, "stride")
    pad_h, pad_w = normalize_2d(padding, "padding")
    if stride_h <= 0 or stride_w <= 0:
        raise ValueError(f"stride must be positive, got {(stride_h, stride_w)}")
    if pad_h < 0 or pad_w < 0:
        raise ValueError(f"padding must be non-negative, got {(pad_h, pad_w)}")

    N, H, W, Ci_orig = input_nhwc.shape
    N2, out_h, out_w, Co = grad_output_nhwc.shape
    # Raised explicitly (not `assert`) so the check survives `python -O`.
    if N != N2:
        raise ValueError(f"Batch size mismatch: input has N={N}, grad_output has N={N2}")

    expected_out_h = (H + 2 * pad_h - R) // stride_h + 1
    expected_out_w = (W + 2 * pad_w - S) // stride_w + 1
    if out_h != expected_out_h or out_w != expected_out_w:
        raise ValueError("Grad-output shape mismatch: expected "
                         f"({N}, {expected_out_h}, {expected_out_w}, {Co}) from input/filter geometry, got "
                         f"({N2}, {out_h}, {out_w}, {Co}).")
    if out_h <= 0 or out_w <= 0:
        raise ValueError("Invalid convolution geometry for wgrad")

    input_nhwc = maybe_pad_ci_for_tma(input_nhwc)
    Ci = input_nhwc.shape[-1]
    K_GEMM = R * S * Ci

    return input_nhwc, grad_output_nhwc, Ci_orig, N, Ci, Co, out_h, out_w, stride_h, stride_w, pad_h, pad_w, K_GEMM


def _allocate_wgrad_output(device, Co, K_GEMM):
    """Allocate the zero-filled float32 [Co, K_GEMM] output buffer."""
    return torch.zeros((Co, K_GEMM), device=device, dtype=torch.float32)


def _make_wgrad_descriptors(input_nhwc, grad_output_nhwc, Co, out_h, out_w, stride_h, stride_w, pad_h, pad_w,
                            input_block_shape, grad_out_block_shape):
    """Create TMA descriptors for wgrad im2col and grad_output."""
    # TMA im2col descriptor for the activation tensor [N, H, W, Ci] in NHWC.
    _, H, W, _ = input_nhwc.shape
    upper_h = (out_h - 1) * stride_h + 1 - H - pad_h
    upper_w = (out_w - 1) * stride_w + 1 - W - pad_w
    input_layout = gl.NVMMASharedLayout.get_default_for(input_block_shape, GL_GEMM_DTYPE)
    in_desc = TensorDescriptorIm2Col(
        base=input_nhwc,
        shape=list(input_nhwc.shape),
        strides=list(input_nhwc.stride()),
        block_shape=input_block_shape,
        layout=input_layout,
        padding="zero",
        element_strides=[1, stride_h, stride_w, 1],
        pixel_box_lower_corner=[-pad_h, -pad_w],
        pixel_box_upper_corner=[upper_h, upper_w],
    )

    # TMA tiled descriptor for grad_output reshaped as (M_spatial, Co).
    M_spatial = input_nhwc.shape[0] * out_h * out_w
    grad_out_2d = grad_output_nhwc.reshape(M_spatial, Co)
    grad_out_layout = gl.NVMMASharedLayout.get_default_for(grad_out_block_shape, GL_GEMM_DTYPE)
    grad_out_desc = TensorDescriptor.from_tensor(grad_out_2d, grad_out_block_shape, grad_out_layout)

    return in_desc, grad_out_desc


def _get_active_split_k(M_spatial, BLOCK_K, SPLIT_K):
    """Host-side mirror of `WgradConfig.get_active_split_k`."""
    total_k_iters = triton.cdiv(M_spatial, BLOCK_K)
    k_iters_per_split = triton.cdiv(total_k_iters, SPLIT_K)
    return triton.cdiv(total_k_iters, k_iters_per_split)


def _make_grid(num_sms, M_spatial, Co, Ci, R, S):
    """Return a launch-grid callable: one CTA per work item, capped at `num_sms`."""

    def grid(meta):
        co_blocks = triton.cdiv(Co, meta["BLOCK_M"])
        ci_blocks = triton.cdiv(Ci, meta["BLOCK_N"])
        # Use the shared helper so the grid and the workspace agree on split-K.
        active_split_k = _get_active_split_k(M_spatial, meta["BLOCK_K"], meta["SPLIT_K"])
        total_tiles = co_blocks * R * S * ci_blocks * active_split_k
        return (min(num_sms, total_tiles), )

    return grid


def _get_safe_wgrad_active_split_k(M_spatial, Co, K_GEMM, kernel_meta):
    """Return the active split-K count, rejecting workspaces too large to index.

    Raises:
        ValueError: if split-K is active and the workspace would hold more
            than 2**31 - 1 elements.
    """
    active_split_k = _get_active_split_k(M_spatial, kernel_meta["BLOCK_K"], kernel_meta["SPLIT_K"])
    if active_split_k > 1:
        # The split-K workspace is indexed as row * stride + col inside the kernel.
        # Very large workspaces can exceed the addressing range supported by the generated code.
        workspace_elems = active_split_k * Co * K_GEMM
        if workspace_elems > (2**31 - 1):
            raise ValueError("wgrad split-K workspace exceeds safe indexing range: "
                             f"active_split_k={active_split_k}, Co={Co}, K_GEMM={K_GEMM}")
    return active_split_k


def _allocate_wgrad_split_k_workspace(device, active_split_k, Co, K_GEMM):
    """Allocate the uninitialised float32 [active_split_k * Co, K_GEMM] partials buffer."""
    return torch.empty((active_split_k * Co, K_GEMM), device=device, dtype=torch.float32)


# Best kernel meta per problem, keyed by `_make_wgrad_autotune_key`.
_wgrad_autotune_cache = {}


def _make_wgrad_autotune_key(
    device,
    num_sms,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
):
    """Build the autotune cache key from device capability, SM count and problem shape."""
    return (
        torch.cuda.get_device_capability(device),
        num_sms,
        N,
        Ci,
        Co,
        R,
        S,
        out_h,
        out_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
    )


def _make_wgrad_runner(
    input_nhwc,
    grad_output_nhwc,
    grad_weight_flat,
    *,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    K_GEMM,
    num_sms,
    kernel_meta,
):
    """Build a zero-argument callable that computes wgrad into `grad_weight_flat`.

    Descriptors, the grid and (when split-K is active) the partials
    workspace are created once here; each call of the returned function
    launches the kernel and, with split-K, the reduction pass.
    """
    M_spatial = N * out_h * out_w
    active_split_k = _get_safe_wgrad_active_split_k(M_spatial, Co, K_GEMM, kernel_meta)
    uses_split_k_workspace = active_split_k > 1
    # Without split-K the kernel writes the final output directly.
    launch_output = grad_weight_flat
    if uses_split_k_workspace:
        launch_output = _allocate_wgrad_split_k_workspace(input_nhwc.device, active_split_k, Co, K_GEMM)

    in_desc, grad_out_desc = _make_wgrad_descriptors(
        input_nhwc,
        grad_output_nhwc,
        Co,
        out_h,
        out_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        [kernel_meta["BLOCK_K"], kernel_meta["BLOCK_N"]],
        [kernel_meta["BLOCK_K"], kernel_meta["BLOCK_M"]],
    )
    grid = _make_grid(num_sms, M_spatial, Co, Ci, R, S)

    def run():
        _launch_wgrad(
            conv2d_wgrad_kernel,
            grid,
            in_desc=in_desc,
            grad_out_desc=grad_out_desc,
            grad_weight=launch_output,
            N=N,
            Ci=Ci,
            Co=Co,
            R=R,
            S=S,
            out_h=out_h,
            out_w=out_w,
            stride_h=stride_h,
            stride_w=stride_w,
            pad_h=pad_h,
            pad_w=pad_w,
            K_GEMM=K_GEMM,
            kernel_meta=kernel_meta,
        )
        if uses_split_k_workspace:
            _reduce_wgrad_split_k_partials(launch_output, grad_weight_flat, Co, K_GEMM, active_split_k)

    return run


def _benchmark_wgrad_config(
    input_nhwc,
    grad_output_nhwc,
    *,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    K_GEMM,
    num_sms,
    kernel_meta,
):
    """Time one kernel configuration.

    Returns:
        The `triton.testing.do_bench` result, or ``float("inf")`` if
        building or running the configuration raises.
    """
    try:
        grad_weight_flat = torch.empty((Co, K_GEMM), device=input_nhwc.device, dtype=torch.float32)
        run = _make_wgrad_runner(
            input_nhwc,
            grad_output_nhwc,
            grad_weight_flat,
            N=N,
            Ci=Ci,
            Co=Co,
            R=R,
            S=S,
            out_h=out_h,
            out_w=out_w,
            stride_h=stride_h,
            stride_w=stride_w,
            pad_h=pad_h,
            pad_w=pad_w,
            K_GEMM=K_GEMM,
            num_sms=num_sms,
            kernel_meta=kernel_meta,
        )
        # Run once outside the timer so a failing config raises here.
        run()
        torch.cuda.synchronize()
        return triton.testing.do_bench(run)
    except Exception:
        # Deliberate best-effort: a failing config is ranked as infinitely
        # slow so the search can continue with the remaining candidates.
        return float("inf")


def _select_wgrad_kernel_meta(
    input_nhwc,
    grad_output_nhwc,
    *,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    K_GEMM,
    num_sms,
):
    """Return the fastest kernel meta for this problem, benchmarking on a cache miss.

    Raises:
        RuntimeError: if every candidate configuration fails.
    """
    cache_key = _make_wgrad_autotune_key(input_nhwc.device, num_sms, N, Ci, Co, R, S, out_h, out_w, stride_h, stride_w,
                                         pad_h, pad_w)
    cached = _wgrad_autotune_cache.get(cache_key)
    if cached is not None:
        # Return a copy so callers cannot mutate the cached entry.
        return dict(cached)

    best_ms = float("inf")
    best_kernel_meta = None
    for config in conv2d_wgrad_get_configs():
        kernel_meta = config.all_kwargs()
        ms = _benchmark_wgrad_config(
            input_nhwc,
            grad_output_nhwc,
            N=N,
            Ci=Ci,
            Co=Co,
            R=R,
            S=S,
            out_h=out_h,
            out_w=out_w,
            stride_h=stride_h,
            stride_w=stride_w,
            pad_h=pad_h,
            pad_w=pad_w,
            K_GEMM=K_GEMM,
            num_sms=num_sms,
            kernel_meta=kernel_meta,
        )
        if ms < best_ms:
            best_ms = ms
            best_kernel_meta = dict(kernel_meta)

    if best_kernel_meta is None:
        raise RuntimeError("Failed to autotune conv2d_wgrad: no valid kernel configurations.")

    _wgrad_autotune_cache[cache_key] = dict(best_kernel_meta)
    return dict(best_kernel_meta)


@triton.jit
def reduce_split_k_partials_kernel(
    partial_ptr,
    grad_weight_ptr,
    partial_stride_0,
    grad_weight_stride_0,
    Co,
    K_GEMM,
    ACTIVE_SPLIT_K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Sum the split-K partial tiles, in split order, into the final output tile."""
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = (offs_m[:, None] < Co) & (offs_n[None, :] < K_GEMM)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for split_k_idx in range(ACTIVE_SPLIT_K):
        # Slice i's partial result lives at row offset i * Co.
        partial_rows = split_k_idx * Co + offs_m
        partial_offsets = partial_rows[:, None] * partial_stride_0 + offs_n[None, :]
        acc += tl.load(partial_ptr + partial_offsets, mask=mask, other=0.0)

    grad_weight_offsets = offs_m[:, None] * grad_weight_stride_0 + offs_n[None, :]
    tl.store(grad_weight_ptr + grad_weight_offsets, acc, mask=mask)


def _reduce_wgrad_split_k_partials(partials, grad_weight_flat, Co, K_GEMM, active_split_k):
    """Launch the reduction kernel over a 64x64-tiled grid of the [Co, K_GEMM] output."""
    BLOCK_M = 64
    BLOCK_N = 64
    grid = (triton.cdiv(Co, BLOCK_M), triton.cdiv(K_GEMM, BLOCK_N))
    reduce_split_k_partials_kernel[grid](
        partials,
        grad_weight_flat,
        partials.stride(0),
        grad_weight_flat.stride(0),
        Co,
        K_GEMM,
        ACTIVE_SPLIT_K=active_split_k,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        num_warps=4,
    )


def _launch_wgrad(
    kernel,
    grid,
    *,
    in_desc,
    grad_out_desc,
    grad_weight,
    N,
    Ci,
    Co,
    R,
    S,
    out_h,
    out_w,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    K_GEMM,
    kernel_meta=None,
):
    """Launch `kernel[grid]` with the wgrad argument list; `kernel_meta` supplies the meta kwargs."""
    if kernel_meta is None:
        kernel_meta = {}
    kernel[grid](
        in_desc,
        grad_out_desc,
        grad_weight,
        N,
        Ci,
        Co,
        R,
        S,
        out_h,
        out_w,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        K_GEMM,
        grad_weight.stride(0),
        **kernel_meta,
    )


def _finalize_wgrad_output(grad_weight_flat, Co, R, S, Ci, Ci_orig):
    """Reshape to [Co, R, S, Ci], cast to the GEMM dtype, and drop padded channels."""
    result = grad_weight_flat.reshape(Co, R, S, Ci).to(TORCH_GEMM_DTYPE)
    if Ci != Ci_orig:
        result = result[:, :, :, :Ci_orig].contiguous()
    return result


def conv2d_wgrad(input_nhwc, grad_output_nhwc, R, S, stride=1, padding=0):
    """Production wgrad entrypoint.

    Selects the best kernel configuration with host-side autotuning, then runs
    deterministic two-pass split-K when reduction is needed.
    """
    (input_nhwc, grad_output_nhwc, Ci_orig, N, Ci, Co, out_h, out_w, stride_h, stride_w, pad_h, pad_w, K_GEMM) = \
        _prepare_wgrad_problem(input_nhwc, grad_output_nhwc, R, S, stride, padding)

    grad_weight_flat = _allocate_wgrad_output(input_nhwc.device, Co, K_GEMM)
    num_sms = torch.cuda.get_device_properties(input_nhwc.device).multi_processor_count
    kernel_meta = _select_wgrad_kernel_meta(
        input_nhwc,
        grad_output_nhwc,
        N=N,
        Ci=Ci,
        Co=Co,
        R=R,
        S=S,
        out_h=out_h,
        out_w=out_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        K_GEMM=K_GEMM,
        num_sms=num_sms,
    )
    run = _make_wgrad_runner(
        input_nhwc,
        grad_output_nhwc,
        grad_weight_flat,
        N=N,
        Ci=Ci,
        Co=Co,
        R=R,
        S=S,
        out_h=out_h,
        out_w=out_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        K_GEMM=K_GEMM,
        num_sms=num_sms,
        kernel_meta=kernel_meta,
    )
    run()

    return _finalize_wgrad_output(grad_weight_flat, Co, R, S, Ci, Ci_orig)


def _make_wgrad_fixed_kernel_meta(SPLIT_K, num_buffers, num_warps):
    """Return the kernel meta used by `conv2d_wgrad_fixed`."""
    # Keep the fixed path on a tile shape that is also covered by autotune configs.
    return {
        "BLOCK_M": 128,
        "BLOCK_N": 256,
        "BLOCK_K": 64,
        "SPLIT_K": SPLIT_K,
        "num_buffers": num_buffers,
        "num_acc_buffers": 2,
        "num_warps": num_warps,
    }


def conv2d_wgrad_fixed(input_nhwc, grad_output_nhwc, R, S, stride=1, padding=0, num_buffers=2, num_warps=4, SPLIT_K=1):
    """Fixed-config wgrad entrypoint used for CI and debugging.

    Runs the kernel with a fixed supported tile shape instead of autotuning,
    while still using deterministic two-pass split-K when reduction is needed.
    """
    (input_nhwc, grad_output_nhwc, Ci_orig, N, Ci, Co, out_h, out_w, stride_h, stride_w, pad_h, pad_w, K_GEMM) = \
        _prepare_wgrad_problem(input_nhwc, grad_output_nhwc, R, S, stride, padding)

    grad_weight_flat = _allocate_wgrad_output(input_nhwc.device, Co, K_GEMM)
    num_sms = torch.cuda.get_device_properties(input_nhwc.device).multi_processor_count
    kernel_meta = _make_wgrad_fixed_kernel_meta(SPLIT_K, num_buffers, num_warps)
    run = _make_wgrad_runner(
        input_nhwc,
        grad_output_nhwc,
        grad_weight_flat,
        N=N,
        Ci=Ci,
        Co=Co,
        R=R,
        S=S,
        out_h=out_h,
        out_w=out_w,
        stride_h=stride_h,
        stride_w=stride_w,
        pad_h=pad_h,
        pad_w=pad_w,
        K_GEMM=K_GEMM,
        num_sms=num_sms,
        kernel_meta=kernel_meta,
    )
    run()

    return _finalize_wgrad_output(grad_weight_flat, Co, R, S, Ci, Ci_orig)


# ===-----------------------------------------------------------------------===#
# Unit Tests
# ===-----------------------------------------------------------------------===#


def _assert_wgrad_correct(wgrad_fn, N, Ci, H, W, Co, R, S, stride, padding, **kwargs):
    """Run wgrad_fn and compare against PyTorch autograd reference."""
    torch.manual_seed(0)
    stride_h, stride_w = normalize_2d(stride, "stride")
    pad_h, pad_w = normalize_2d(padding, "padding")

    x_nchw = torch.randn((N, Ci, H, W), device="cuda", dtype=TORCH_GEMM_DTYPE)
    x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()

    out_h = (H + 2 * pad_h - R) // stride_h + 1
    out_w = (W + 2 * pad_w - S) // stride_w + 1
    grad_out_nchw = torch.randn((N, Co, out_h, out_w), device="cuda", dtype=TORCH_GEMM_DTYPE)
    grad_out_nhwc = grad_out_nchw.permute(0, 2, 3, 1).contiguous()

    # Reference: backpropagate grad_out through torch's conv2d to get dL/dW.
    w_nchw = torch.randn((Co, Ci, R, S), device="cuda", dtype=TORCH_GEMM_DTYPE)
    w_ref = w_nchw.detach().requires_grad_(True)
    out_ref = torch.nn.functional.conv2d(x_nchw, w_ref, stride=(stride_h, stride_w), padding=(pad_h, pad_w))
    out_ref.backward(grad_out_nchw)
    # [Co, Ci, R, S] -> [Co, R, S, Ci] to match the kernel's output layout.
    ref_grad_w_nhwc = w_ref.grad.permute(0, 2, 3, 1).contiguous()

    triton_grad_w = wgrad_fn(x_nhwc, grad_out_nhwc, R, S, stride=stride, padding=padding, **kwargs)

    torch.testing.assert_close(triton_grad_w, ref_grad_w_nhwc, atol=1, rtol=0.01)


@pytest.mark.parametrize("wgrad_fn,N,Ci,H,W,Co,R,S,stride,padding", [
    *[(conv2d_wgrad_fixed, N, Ci, H, W, Co, R, S, stride, padding)
      for N in (1, 128)
      for H, W in ((64, 64), (64, 32))
      for Ci, Co in ((128, 128), (384, 384), (128, 384))
      for R, S in ((1, 1), (2, 2), (3, 3), (1, 3))
      for stride in (1, 2, 3)
      for padding in (0, 1)],
    (conv2d_wgrad_fixed, 16, 5, 32, 32, 96, 3, 3, 1, 1),  # padded channels
    (conv2d_wgrad_fixed, 16, 96, 1, 8, 128, 1, 2, (1, 2), 0),  # asymmetric stride
    (conv2d_wgrad_fixed, 16, 512, 2, 2, 768, 2, 2, (2, 2), 0),  # small spatial
])
@pytest.mark.skipif(not is_blackwell(), reason="Requires Blackwell GPU (SM 10.x)")
def test_op(wgrad_fn, N, Ci, H, W, Co, R, S, stride, padding):
    """Check the fixed-config wgrad path against the PyTorch reference."""
    _assert_wgrad_correct(wgrad_fn, N, Ci, H, W, Co, R, S, stride, padding)


# ===-----------------------------------------------------------------------===#
# Benchmarking
# ===-----------------------------------------------------------------------===#

BATCH = [128]
CHANNELS = [(384, 384)]
SPATIAL = [(64, 64)]
FILTER = [(3, 3)]
STRIDE = [1]
PADDING = [1]


def _make_bench_inputs(N, H, W, Ci, Co, R, S, stride_val, pad_val):
    """Create seeded random input and grad-output tensors in NCHW and contiguous NHWC form."""
    torch.manual_seed(0)
    out_h = (H + 2 * pad_val - R) // stride_val + 1
    out_w = (W + 2 * pad_val - S) // stride_val + 1
    x_nchw = torch.randn((N, Ci, H, W), device="cuda", dtype=TORCH_GEMM_DTYPE)
    x_nhwc = x_nchw.permute(0, 2, 3, 1).contiguous()
    grad_out_nchw = torch.randn((N, Co, out_h, out_w), device="cuda", dtype=TORCH_GEMM_DTYPE)
    grad_out_nhwc = grad_out_nchw.permute(0, 2, 3, 1).contiguous()
    return x_nchw, x_nhwc, grad_out_nchw, grad_out_nhwc


def _benchmark_tflops(fn, *, N, H, W, Ci, Co, R, S, stride_val, pad_val):
    """Time `fn` with `do_bench` and convert to TFLOPS (2 FLOPs per multiply-accumulate)."""
    ms = triton.testing.do_bench(fn)
    out_h = (H + 2 * pad_val - R) // stride_val + 1
    out_w = (W + 2 * pad_val - S) // stride_val + 1
    flops = 2.0 * N * out_h * out_w * Co * Ci * R * S
    return flops * 1e-12 / (ms * 1e-3)


bench_configs = []
for N, (Ci, Co), (H, W), (R, S), stride_val, pad_val in [(N, ch, sp, f, s, p)
                                                         for N in BATCH
                                                         for ch in CHANNELS
                                                         for sp in SPATIAL
                                                         for f in FILTER
                                                         for s in STRIDE
                                                         for p in PADDING]:
    bench_configs.append(
        triton.testing.Benchmark(
            x_names=["kernel"],
            x_vals=["autotuned"],
            line_arg="provider",
            line_vals=["gluon", "torch"],
            line_names=["Gluon (autotuned)", "PyTorch"],
            styles=[("green", "-"), ("blue", "-")],
            ylabel="TFLOPS",
            plot_name=f"Wgrad N={N} Ci={Ci} Co={Co} H={H} W={W} R={R} S={S} stride={stride_val} pad={pad_val}",
            args={
                "N": N,
                "H": H,
                "W": W,
                "Ci": Ci,
                "Co": Co,
                "R": R,
                "S": S,
                "stride_val": stride_val,
                "pad_val": pad_val,
            },
        ))


@triton.testing.perf_report(bench_configs)
def bench(N, H, W, Ci, Co, R, S, stride_val, pad_val, kernel, provider):
    """Benchmark the autotuned Gluon wgrad against `aten.convolution_backward` (weight grad only)."""
    x_nchw, x_nhwc, grad_out_nchw, grad_out_nhwc = \
        _make_bench_inputs(N, H, W, Ci, Co, R, S, stride_val, pad_val)

    if provider == "gluon":
        fn = lambda: conv2d_wgrad(x_nhwc, grad_out_nhwc, R, S, stride=stride_val, padding=pad_val)
    elif provider == "torch":
        w_nchw = torch.randn((Co, Ci, R, S), device="cuda", dtype=TORCH_GEMM_DTYPE)
        fn = lambda: torch.ops.aten.convolution_backward(
            grad_out_nchw,
            x_nchw,
            w_nchw,
            bias_sizes=None,
            stride=[stride_val, stride_val],
            padding=[pad_val, pad_val],
            dilation=[1, 1],
            transposed=False,
            output_padding=[0, 0],
            groups=1,
            # Only the weight gradient is requested.
            output_mask=[False, True, False],
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")

    return _benchmark_tflops(
        fn,
        N=N,
        H=H,
        W=W,
        Ci=Ci,
        Co=Co,
        R=R,
        S=S,
        stride_val=stride_val,
        pad_val=pad_val,
    )


if __name__ == "__main__":
    bench.run(save_path=".", print_data=True)
```