TMA im2col mode and Convolution via Implicit GEMM
This tutorial explains NVIDIA TMA’s im2col mode and then applies it to an end-to-end convolution kernel.
We cover:
TMA im2col fundamentals on a 4-D NHWC tensor
How TensorDescriptorIm2Col parameters control access boundaries
Practical loading patterns (basic, padded, shifted-offset, multi-batch)
How 2D convolution works (the sliding window)
The im2col algorithm that reshapes convolution into GEMM
A convolution kernel using TMA im2col + MMA (works on both Hopper and Blackwell)
We re-use the MMA abstraction from 07-persistence.py so that the same kernel
runs on Hopper (WGMMA) and Blackwell (tcgen05 MMA) without code duplication.
Prerequisites:
04-tma.py: TMA basics (tensor descriptors, async copies)
05-wgmma.py: Hopper warp-group MMA (WGMMA) basics
06-tcgen05.py: tcgen05 MMA for matrix multiplication on Blackwell
07-persistence.py: MMA abstraction (WGMMA / MMAv5) and persistent kernels
Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode
TMA im2col Mode Parameters
The following descriptor fields and boundary rules are specific to TensorDescriptorIm2Col in im2col mode.
Block Shape: block_shape = [pixelsPerColumn, channelsPerPixel]
- pixelsPerColumn: number of pixels from spatial dims [N, H, W]
- channelsPerPixel: number of channels per pixel from [C]
TensorDescriptor Parameters:
- Input tensor: [N, H, W, C] in NHWC format
- pixel_box_lower_corner: [H_lo, W_lo] offset from (0, 0)
- pixel_box_upper_corner: [H_hi, W_hi] offset from [H - 1, W - 1]
- element_strides: [1, 1, 1, 1] for contiguous element access
- padding: “zero” or “nan” for out-of-bounds fills
Access Boundary: pixel_box_lower_corner, pixel_box_upper_corner, and offsets define the spatial access window:
- Lower bound = pixel_box_lower_corner + offsets
- Upper bound = [H - 1, W - 1] + pixel_box_upper_corner + offsets
- Bounds are interpreted as closed (inclusive): [lower, upper]
The input tensor in global memory is NOT padded. When accessing outside
[0, 0] to [H-1, W-1], TMA fills with the padding value.
async_copy_global_to_shared_im2col:
async_copy_global_to_shared_im2col(tensor_desc, coord, offsets, barrier, result)
- coord: [batch_idx, start_h, start_w, channel_start] start coords
- offsets: [h_offset, w_offset] spatial offsets (i16)
Starting position (the first pixel): (batch_idx, start_h + h_offset, start_w + w_offset, channel_start)
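The boundary and starting-position rules above are plain integer arithmetic, so they can be checked on the host before any kernel is written. The sketch below is a hypothetical helper pair (access_boundary and first_pixel are illustrative names, not part of Gluon) that only restates those rules in Python.

```python
def access_boundary(H, W, lower_corner, upper_corner, offsets=(0, 0)):
    """Inclusive [lower, upper] spatial bounds, following the rules stated above."""
    lower = [lower_corner[0] + offsets[0], lower_corner[1] + offsets[1]]
    upper = [H - 1 + upper_corner[0] + offsets[0],
             W - 1 + upper_corner[1] + offsets[1]]
    return lower, upper


def first_pixel(coord, offsets=(0, 0)):
    """Starting (batch, h, w, channel) position of a load."""
    batch_idx, start_h, start_w, channel_start = coord
    return (batch_idx, start_h + offsets[0], start_w + offsets[1], channel_start)


# 4x4 image, pixel box shifted by -1, no offset: H in [-1, 2], W in [-1, 2]
print(access_boundary(4, 4, [-1, -1], [-1, -1]))          # ([-1, -1], [2, 2])
# Same descriptor with offsets=[1, 1]: the window shifts to [0, 3]
print(access_boundary(4, 4, [-1, -1], [-1, -1], (1, 1)))  # ([0, 0], [3, 3])
print(first_pixel([0, -1, -1, 0], (1, 1)))                # (0, 0, 0, 0)
```

These are the same configurations used in Examples 2 and 3 later in this tutorial.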
import importlib
import pytest
import torch
import triton
import triton.language as tl
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor, TensorDescriptorIm2Col
from triton.experimental.gluon.language.nvidia.hopper import fence_async_shared, mbarrier, tma
t7 = importlib.import_module("07-persistence")
if __name__ == "__main__" and not t7.is_hopper_or_newer():
    raise RuntimeError("This tutorial requires Hopper or newer NVIDIA GPU")
Tutorial Motivation
Before diving into kernels, we first explain the concept of im2col convolution: rewriting convolution as GEMM by flattening each sliding-window patch into a row.
Then we motivate why TMA im2col mode is helpful for this pattern: it performs convolution-style address generation and padding handling in hardware, so the kernel uses simpler indexing and can spend more resources on MMA compute.
What is a 2D Convolution?
A 2D convolution slides a small filter (kernel) across an input image and computes a weighted sum at every position. For a single-channel example with a 3x3 input and a 2x2 filter (stride=1, no padding), the output is 2x2:
Input (3x3): Filter (2x2): Output (2x2):
+---+---+---+ +----+----+ +----+----+
| a | b | c | | w0 | w1 | | y0 | y1 |
+---+---+---+ +----+----+ +----+----+
| d | e | f | | w2 | w3 | | y2 | y3 |
+---+---+---+ +----+----+ +----+----+
| g | h | i |
+---+---+---+
Each output element is the dot product of the filter with a region of the input:
y0 = a*w0 + b*w1 + d*w2 + e*w3
y1 = b*w0 + c*w1 + e*w2 + f*w3
y2 = d*w0 + e*w1 + g*w2 + h*w3
y3 = e*w0 + f*w1 + h*w2 + i*w3
With multiple input channels (Ci) and multiple output channels (Co):
y[n, oh, ow, co] = SUM over (r, s, ci) of:
input[n, oh*stride + r - pad, ow*stride + s - pad, ci] * weight[co, r, s, ci]
This nested summation over (r, s, ci) is a dot product of length R * S * Ci for every output element, which is what lets us rewrite convolution as matrix multiplication.
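As a reference point, the formula above can be written as a direct loop nest. This is a minimal pure-Python sketch on nested lists (conv2d_direct is an illustrative name, and the output-size expression is the usual floor formula); it is meant for checking small cases by hand, not for performance.

```python
def conv2d_direct(inp, weight, stride=1, pad=0):
    """inp: [N][H][W][Ci], weight: [Co][R][S][Ci] -> out: [N][out_h][out_w][Co]."""
    N, H, W, Ci = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    Co, R, S = len(weight), len(weight[0]), len(weight[0][0])
    out_h = (H + 2 * pad - R) // stride + 1
    out_w = (W + 2 * pad - S) // stride + 1
    out = [[[[0] * Co for _ in range(out_w)] for _ in range(out_h)] for _ in range(N)]
    for n in range(N):
        for oh in range(out_h):
            for ow in range(out_w):
                for co in range(Co):
                    acc = 0
                    for r in range(R):
                        for s in range(S):
                            h, w = oh * stride + r - pad, ow * stride + s - pad
                            if 0 <= h < H and 0 <= w < W:  # zero padding adds nothing
                                for ci in range(Ci):
                                    acc += inp[n][h][w][ci] * weight[co][r][s][ci]
                    out[n][oh][ow][co] = acc
    return out


# The 3x3 example with a..i = 1..9 and w0..w3 = 1..4 (single channel).
image = [[[[1], [2], [3]], [[4], [5], [6]], [[7], [8], [9]]]]
filt = [[[[1], [2]], [[3], [4]]]]
print(conv2d_direct(image, filt))  # [[[[37], [47]], [[67], [77]]]]
```

For instance, y0 = 1*1 + 2*2 + 4*3 + 5*4 = 37, matching the expansion of y0 given earlier.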
The im2col Algorithm
im2col (image to column) rearranges the input so that convolution becomes a standard matrix multiplication (GEMM). The idea is:
For each output position (n, oh, ow), extract the R x S x Ci input patch that the filter overlaps with.
Flatten this patch into a single row of length R * S * Ci.
Stack all such rows into a matrix A of shape [M, K].
Reshape the filter into a matrix W of shape [Co, K].
The output is simply A @ W^T (a GEMM!).
where M = N * out_h * out_w and K = R * S * Ci.
Example: 3x3 input, 2x2 filter, stride=1, no padding => 2x2 output
Step 1: Extract patches for each output position
-------------------------------------------------
y0 at (0,0): patch = [a, b, d, e]
y1 at (0,1): patch = [b, c, e, f]
y2 at (1,0): patch = [d, e, g, h]
y3 at (1,1): patch = [e, f, h, i]
Step 2: Stack patches into im2col matrix A (M=4, K=4)
-----------------------------------------------------
K = R*S*Ci = 4
<------------->
A = | a b d e | <- y0 ^
| b c e f | <- y1 | M = 4
| d e g h | <- y2 |
| e f h i | <- y3 v
Note: overlapping patches share input elements (e.g., 'e' appears in all 4 rows).
Step 3: Reshape filter into weight matrix W (Co=1, K=4)
-------------------------------------------------------
W = | w0 w1 w2 w3 |
Step 4: Output = A @ W^T
------------------------
                 | w0 |     | y0 |
A @ W^T = A  x   | w1 |  =  | y1 |
(4x4)(4x1)       | w2 |     | y2 |
                 | w3 |     | y3 |
                            (4x1)
With Co output channels, W is (Co x K) and Output is (M x Co).
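The four steps above can be sketched in a few lines of pure Python. The helper names (im2col_rows, matmul_a_wt) are illustrative only, and the code builds A explicitly, which is the approach that implicit GEMM later avoids.

```python
def im2col_rows(inp, R, S, stride=1, pad=0):
    """Explicit im2col: one flattened R*S*Ci patch per row of A (inp is [N][H][W][Ci])."""
    N, H, W, Ci = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    out_h = (H + 2 * pad - R) // stride + 1
    out_w = (W + 2 * pad - S) // stride + 1
    rows = []
    for n in range(N):
        for oh in range(out_h):
            for ow in range(out_w):
                row = []
                for r in range(R):
                    for s in range(S):
                        h, w = oh * stride + r - pad, ow * stride + s - pad
                        inside = 0 <= h < H and 0 <= w < W
                        for ci in range(Ci):
                            row.append(inp[n][h][w][ci] if inside else 0)
                rows.append(row)
    return rows


def matmul_a_wt(A, Wm):
    """Compute A @ Wm^T for A of shape [M, K] and Wm of shape [Co, K]."""
    return [[sum(a * b for a, b in zip(a_row, w_row)) for w_row in Wm] for a_row in A]


image = [[[[1], [2], [3]], [[4], [5], [6]], [[7], [8], [9]]]]  # a..i = 1..9
A = im2col_rows(image, R=2, S=2)
Wm = [[1, 2, 3, 4]]  # w0..w3, Co = 1
print(A)                   # [[1, 2, 4, 5], [2, 3, 5, 6], [4, 5, 7, 8], [5, 6, 8, 9]]
print(matmul_a_wt(A, Wm))  # [[37], [47], [67], [77]]
```

The rows of A are the patches [a, b, d, e], [b, c, e, f], [d, e, g, h], [e, f, h, i], and the value 5 (the centre pixel e) appears in every row, which is the duplication that explicit im2col pays for.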
Why TMA im2col mode is helpful for convolution:
In the matrix view above, each row of A is one output position’s patch.
The indexing term oh*stride + r - pad / ow*stride + s - pad must be evaluated for every filter tap and channel block in the inner loop.
Software address generation and bounds handling increase integer ALU work and register pressure right where we want to maximize MMA throughput.
TensorDescriptorIm2Col encodes stride/padding/pixel-box once, and TMA hardware computes the gather addresses and out-of-bounds padding fills.
The kernel can focus on MMA scheduling while using simple coord + offsets loads, which is both cleaner and typically faster.
Example (multiple TMA im2col loads -> rows of A):
Use the same setup as above: 3x3 input, 2x2 filter, stride=1, pad=0.
Input (same symbols as the matrix example):
W=0 W=1 W=2
H=0 a b c
H=1 d e f
H=2 g h i
For each output position (oh, ow), TMA im2col uses:
coord = [n, oh*stride - pad, ow*stride - pad, ci_blk]
Then one load is issued per filter offset (r, s):
offsets=[0,0], [0,1], [1,0], [1,1]
(oh=0, ow=0), coord=[n,0,0,ci_blk] -> [a, b, d, e] = first row of A
Each letter stands for one pixel; its extent along the channel dimension is Ci.
(oh=0, ow=1), coord=[n,0,1,ci_blk] -> [b, c, e, f] = second row of A
(oh=1, ow=0), coord=[n,1,0,ci_blk] -> [d, e, g, h] = third row of A
(oh=1, ow=1), coord=[n,1,1,ci_blk] -> [e, f, h, i] = fourth row of A
This is the "image to column" idea in the row-per-patch orientation used
throughout this tutorial: each output-position patch from the image is
flattened and stored as one row of A.
Stacking these multiple loads across output positions forms A:
| a b d e |
| b c e f |
| d e g h |
| e f h i |
In the real kernel, this is executed in the K-loop as one
async_copy_global_to_shared_im2col per (r, s, ci_block).
Example 1: Basic Loading (No Padding)
Load all 16 pixels (4x4 grid) from a [1, 4, 4, 32] tensor.
coord = [0, 0, 0, 0], offsets = [0, 0]
Access boundary: H in [0, 3], W in [0, 3] (original tensor)
def run_tma_im2col_simple():
    inp, out = run_tma_im2col(
        title="Example 1: Basic loading (no padding)",
        pixel_box_lower_corner=[0, 0],
        pixel_box_upper_corner=[0, 0],
        coord=[0, 0, 0, 0],
        offsets=[0, 0],
    )
    torch.testing.assert_close(out, inp.reshape(16, 32), atol=0, rtol=0)


@pytest.mark.skipif(not t7.is_hopper_or_newer(), reason="Requires Hopper")
def test_tma_im2col_simple():
    run_tma_im2col_simple()
Example 1 Explanation:
Configuration:
pixel_box_lower_corner = [0, 0], pixel_box_upper_corner = [0, 0] (no padding)
coord = [0, 0, 0, 0], offsets = [0, 0]
Access boundary: H in [0, 3], W in [0, 3]
Expected output:
Loaded data: pixel 0->1, pixel 1->2, ..., pixel 15->16
Input [1, 4, 4, 32] spatial layout: Output [16, 32] pixel order:
Channel dimension is contiguous: each pixel stores 32 channels in order.
+----+----+----+----+ pixel 0 = (H=0,W=0) -> value 1
| 1 | 2 | 3 | 4 | H=0 pixel 1 = (H=0,W=1) -> value 2
+----+----+----+----+ ...
| 5 | 6 | 7 | 8 | H=1 pixel 15 = (H=3,W=3) -> value 16
+----+----+----+----+
| 9 | 10 | 11 | 12 | H=2
+----+----+----+----+
| 13 | 14 | 15 | 16 | H=3
+----+----+----+----+
W=0 W=1 W=2 W=3
if __name__ == "__main__":
    run_tma_im2col_simple()
    print("\n" + "=" * 50 + "\n")
Example 1: Basic loading (no padding)
Loaded data (pixel_id, value):
pixel 0: 1
pixel 1: 2
pixel 2: 3
pixel 3: 4
pixel 4: 5
pixel 5: 6
pixel 6: 7
pixel 7: 8
pixel 8: 9
pixel 9: 10
pixel 10: 11
pixel 11: 12
pixel 12: 13
pixel 13: 14
pixel 14: 15
pixel 15: 16
==================================================
Example 2: Loading from Padded Region
With pixel_box_lower_corner=[-1,-1] and pixel_box_upper_corner=[-1,-1], the access boundary is H in [-1, 2], W in [-1, 2]. Negative coordinates lie outside the tensor, so those pixels are filled with the padding value (0).
def run_tma_im2col_padded():
    _, out = run_tma_im2col(
        title="Example 2: Loading from padded region",
        pixel_box_lower_corner=[-1, -1],
        pixel_box_upper_corner=[-1, -1],
        coord=[0, -1, -1, 0],
        offsets=[0, 0],
    )
    # Expected: row H=-1 all padded, then each row has W=-1 padded
    expected = torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 5, 6, 7, 0, 9, 10, 11], device="cuda", dtype=torch.float32)
    torch.testing.assert_close(out[:, 0], expected, atol=0, rtol=0)


@pytest.mark.skipif(not t7.is_hopper_or_newer(), reason="Requires Hopper")
def test_tma_im2col_padded():
    run_tma_im2col_padded()
Example 2 Explanation:
Configuration:
pixel_box_lower_corner = [-1, -1], pixel_box_upper_corner = [-1, -1]
coord = [0, -1, -1, 0], offsets = [0, 0]
Access Boundary with offsets=[0, 0]:
Lower = pixel_box_lower_corner + offsets = [-1, -1] + [0, 0]
Upper = [H - 1, W - 1] + pixel_box_upper_corner + offsets = [3, 3] + [-1, -1] + [0, 0]
Result: H in [-1, 2], W in [-1, 2]
Negative coordinates are outside tensor -> filled with padding (0)
Expected output:
pixel 0: 0 (H=-1, W=-1) padded pixel 8: 0 (H=1, W=-1) padded
pixel 1: 0 (H=-1, W=0) padded pixel 9: 5 (H=1, W=0) actual
pixel 2: 0 (H=-1, W=1) padded pixel 10: 6 (H=1, W=1) actual
pixel 3: 0 (H=-1, W=2) padded pixel 11: 7 (H=1, W=2) actual
pixel 4: 0 (H=0, W=-1) padded pixel 12: 0 (H=2, W=-1) padded
pixel 5: 1 (H=0, W=0) actual pixel 13: 9 (H=2, W=0) actual
pixel 6: 2 (H=0, W=1) actual pixel 14: 10 (H=2, W=1) actual
pixel 7: 3 (H=0, W=2) actual pixel 15: 11 (H=2, W=2) actual
Visualization:
Original 4x4 tensor: Access boundary H in [-1,2], W in [-1,2]:
Note: padded cells are not stored in global memory; TMA fills them on load.
W=0 W=1 W=2 W=3 W=-1 W=0 W=1 W=2
+----+----+----+----+ +----+----+----+----+
H=0 | 1 | 2 | 3 | 4 | H=-1| 0 | 0 | 0 | 0 | <- padded
+----+----+----+----+ +----+----+----+----+
H=1 | 5 | 6 | 7 | 8 | H=0 | 0 | 1 | 2 | 3 |
+----+----+----+----+ +----+----+----+----+
H=2 | 9 | 10 | 11 | 12 | H=1 | 0 | 5 | 6 | 7 |
+----+----+----+----+ +----+----+----+----+
H=3 | 13 | 14 | 15 | 16 | H=2 | 0 | 9 | 10 | 11 |
+----+----+----+----+ +----+----+----+----+
^
padded (W<0)
if __name__ == "__main__":
    run_tma_im2col_padded()
    print("\n" + "=" * 50 + "\n")
Example 2: Loading from padded region
Loaded data (pixel_id, value):
pixel 0: 0
pixel 1: 0
pixel 2: 0
pixel 3: 0
pixel 4: 0
pixel 5: 1
pixel 6: 2
pixel 7: 3
pixel 8: 0
pixel 9: 5
pixel 10: 6
pixel 11: 7
pixel 12: 0
pixel 13: 9
pixel 14: 10
pixel 15: 11
==================================================
Example 3: Offset Shifts Access Boundary
Same TensorDescriptor as Example 2, but with offsets=[1, 1]. This shifts the access boundary from [-1, 2] to [0, 3], covering the original tensor.
def run_tma_im2col_offset():
    inp, out = run_tma_im2col(
        title="Example 3: Offset shifts access to original tensor",
        pixel_box_lower_corner=[-1, -1],
        pixel_box_upper_corner=[-1, -1],
        coord=[0, -1, -1, 0],
        offsets=[1, 1],
    )
    # Offset shifts access to exactly the original tensor [0,3] x [0,3]
    torch.testing.assert_close(out, inp.reshape(16, 32), atol=0, rtol=0)


@pytest.mark.skipif(not t7.is_hopper_or_newer(), reason="Requires Hopper")
def test_tma_im2col_offset():
    run_tma_im2col_offset()
Example 3 Explanation:
Configuration:
pixel_box_lower_corner = [-1, -1], pixel_box_upper_corner = [-1, -1]
coord = [0, -1, -1, 0], offsets = [1, 1]
Access Boundary with offsets=[1, 1]:
Lower = pixel_box_lower_corner + offsets = [-1, -1] + [1, 1]
Upper = [H - 1, W - 1] + pixel_box_upper_corner + offsets = [3, 3] + [-1, -1] + [1, 1]
Result: H in [0, 3], W in [0, 3] -> exactly the original tensor
Comparison (Example 2 vs Example 3):
Example 2 (offsets=[0,0]): Example 3 (offsets=[1,1]):
Access H in [-1, 2], W in [-1, 2] Access H in [0, 3], W in [0, 3]
W=-1 W=0 W=1 W=2 W=0 W=1 W=2 W=3
+----+----+----+----+ +----+----+----+----+
H=-1| 0 | 0 | 0 | 0 | H=0 | 1 | 2 | 3 | 4 |
+----+----+----+----+ +----+----+----+----+
H=0 | 0 | 1 | 2 | 3 | H=1 | 5 | 6 | 7 | 8 |
+----+----+----+----+ +----+----+----+----+
H=1 | 0 | 5 | 6 | 7 | H=2 | 9 | 10 | 11 | 12 |
+----+----+----+----+ +----+----+----+----+
H=2 | 0 | 9 | 10 | 11 | H=3 | 13 | 14 | 15 | 16 |
+----+----+----+----+ +----+----+----+----+
Key insight: offsets=[1,1] shifts the 4x4 access window from
the padded region to exactly cover the original tensor data.
if __name__ == "__main__":
    run_tma_im2col_offset()
    print("\n" + "=" * 50 + "\n")
Example 3: Offset shifts access to original tensor
Loaded data (pixel_id, value):
pixel 0: 1
pixel 1: 2
pixel 2: 3
pixel 3: 4
pixel 4: 5
pixel 5: 6
pixel 6: 7
pixel 7: 8
pixel 8: 9
pixel 9: 10
pixel 10: 11
pixel 11: 12
pixel 12: 13
pixel 13: 14
pixel 14: 15
pixel 15: 16
==================================================
Example 4: Loading Across Multiple Batches
This example shows loading pixels that span multiple images (batches).
Input: [2, 4, 4, 32] - 2 batches of 4x4 images with 32 channels.
Starting at coord=[0, 1, 3, 0] (batch 0, H=1, W=3), loading 16 pixels wraps from batch 0 into batch 1.
def run_tma_im2col_multi_batch():
    # Input: [2, 4, 4, 32] - 2 batches of 4x4 images (32 pixels total)
    # Values 1-32 for easy tracking
    inp = torch.arange(1, 33, device="cuda", dtype=torch.float32).unsqueeze(1).repeat(1, 32).reshape(2, 4, 4, 32)
    out = torch.zeros(16, 32, device="cuda", dtype=torch.float32)
    block_shape = [16, 32]  # load 16 pixels
    layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2)
    in_desc = TensorDescriptorIm2Col.from_tensor(
        inp,
        block_shape,
        layout,
        padding="zero",
        element_strides=[1, 1, 1, 1],
        pixel_box_lower_corner=[0, 0],
        pixel_box_upper_corner=[0, 0],
    )
    out_desc = TensorDescriptor.from_tensor(out, block_shape, layout)
    # coord=[0, 1, 3, 0] - start from batch 0, H=1, W=3, C=0
    # This is 0-based pixel index 1*4 + 3 = 7 in batch 0, which holds value 8
    tma_im2col_kernel[(1, )](in_desc, out_desc, 0, 1, 3, 0, 0, 0, num_warps=1)
    print("Example 4: Loading across multiple batches")
    print("Input shape: [2, 4, 4, 32] (2 batches of 4x4 images)")
    print("Starting at: batch=0, H=1, W=3 (pixel 8 in input)")
    print("\nLoaded data (pixel_id, value):")
    for pixel_id in range(16):
        value = int(out[pixel_id, 0].item())
        print(f" pixel {pixel_id:2d}: value {value:2d}")
    # Expected: values 8-16 from batch 0, then 17-23 from batch 1
    expected = torch.tensor([8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], device="cuda",
                            dtype=torch.float32)
    torch.testing.assert_close(out[:, 0], expected, atol=0, rtol=0)


@pytest.mark.skipif(not t7.is_hopper_or_newer(), reason="Requires Hopper")
def test_tma_im2col_multi_batch():
    run_tma_im2col_multi_batch()
Example 4 Explanation:
Configuration:
Input: [2, 4, 4, 32] - 2 batches of 4x4 images
block_shape = [16, 32] - load 16 pixels
coord = [0, 1, 3, 0] - start at batch 0, H=1, W=3
Input layout (values 1-32):
Batch 0 (4x4): Batch 1 (4x4):
[x] = accessed pixel [x] = accessed pixel
W=0 W=1 W=2 W=3 W=0 W=1 W=2 W=3
+----+----+----+----+ +----+----+----+----+
H=0 | 1 | 2 | 3 | 4 | H=0 |[17]|[18]|[19]|[20]|
+----+----+----+----+ +----+----+----+----+
H=1 | 5 | 6 | 7 | [8]| <-start H=1 |[21]|[22]|[23]| 24 |
+----+----+----+----+ +----+----+----+----+
H=2 |[ 9]|[10]|[11]|[12]| H=2 | 25 | 26 | 27 | 28 |
+----+----+----+----+ +----+----+----+----+
H=3 |[13]|[14]|[15]|[16]| H=3 | 29 | 30 | 31 | 32 |
+----+----+----+----+ +----+----+----+----+
Starting at (N=0, H=1, W=3), loading 16 pixels in row-major order:
pixel 0: (N=0, H=1, W=3) -> value 8 (start position)
pixel 1: (N=0, H=2, W=0) -> value 9
pixel 2: (N=0, H=2, W=1) -> value 10
pixel 3: (N=0, H=2, W=2) -> value 11
pixel 4: (N=0, H=2, W=3) -> value 12
pixel 5: (N=0, H=3, W=0) -> value 13
pixel 6: (N=0, H=3, W=1) -> value 14
pixel 7: (N=0, H=3, W=2) -> value 15
pixel 8: (N=0, H=3, W=3) -> value 16 (last pixel in batch 0)
pixel 9: (N=1, H=0, W=0) -> value 17 (wraps to batch 1)
pixel 10: (N=1, H=0, W=1) -> value 18
pixel 11: (N=1, H=0, W=2) -> value 19
pixel 12: (N=1, H=0, W=3) -> value 20
pixel 13: (N=1, H=1, W=0) -> value 21
pixel 14: (N=1, H=1, W=1) -> value 22
pixel 15: (N=1, H=1, W=2) -> value 23
Key insight: loading wraps from batch 0 into batch 1 without any extra index arithmetic in the kernel, so a single block can cover pixels from consecutive images.
if __name__ == "__main__":
    run_tma_im2col_multi_batch()
    print("\n" + "=" * 50 + "\n")
Example 4: Loading across multiple batches
Input shape: [2, 4, 4, 32] (2 batches of 4x4 images)
Starting at: batch=0, H=1, W=3 (pixel 8 in input)
Loaded data (pixel_id, value):
pixel 0: value 8
pixel 1: value 9
pixel 2: value 10
pixel 3: value 11
pixel 4: value 12
pixel 5: value 13
pixel 6: value 14
pixel 7: value 15
pixel 8: value 16
pixel 9: value 17
pixel 10: value 18
pixel 11: value 19
pixel 12: value 20
pixel 13: value 21
pixel 14: value 22
pixel 15: value 23
==================================================
Example 5: Multi-Batch with Padded Access Boundary
This example combines multi-batch loading with padding. With pixel_box_lower_corner=[-1,-1] and pixel_box_upper_corner=[-1,-1], the access boundary is H in [-1, 2], W in [-1, 2] per batch. A load that starts at H=1, W=2 visits pixels that wrap across batches while staying inside this constrained boundary.
def run_tma_im2col_multi_batch_padded():
    # Input: [2, 4, 4, 32] - 2 batches of 4x4 images (32 pixels total)
    inp = torch.arange(1, 33, device="cuda", dtype=torch.float32).unsqueeze(1).repeat(1, 32).reshape(2, 4, 4, 32)
    out = torch.zeros(16, 32, device="cuda", dtype=torch.float32)
    block_shape = [16, 32]
    layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=32, rank=2)
    in_desc = TensorDescriptorIm2Col.from_tensor(
        inp,
        block_shape,
        layout,
        padding="zero",
        element_strides=[1, 1, 1, 1],
        pixel_box_lower_corner=[-1, -1],
        pixel_box_upper_corner=[-1, -1],
    )
    out_desc = TensorDescriptor.from_tensor(out, block_shape, layout)
    # coord=[0, 1, 2, 0] - start from batch 0, H=1, W=2, C=0
    tma_im2col_kernel[(1, )](in_desc, out_desc, 0, 1, 2, 0, 0, 0, num_warps=1)
    print("Example 5: Multi-batch with padded access boundary")
    print("Input shape: [2, 4, 4, 32], pixel_box_lower_corner=[-1,-1], "
          "pixel_box_upper_corner=[-1,-1]")
    print("Access boundary per batch: H in [-1, 2], W in [-1, 2]")
    print("Starting at: batch=0, H=1, W=2")
    print("\nLoaded data (pixel_id, value):")
    for pixel_id in range(16):
        value = int(out[pixel_id, 0].item())
        print(f" pixel {pixel_id:2d}: {value}")
    # Expected: 7, then padding + data pattern as described in documentation
    expected = torch.tensor([7, 0, 9, 10, 11, 0, 0, 0, 0, 0, 17, 18, 19, 0, 21, 22], device="cuda", dtype=torch.float32)
    torch.testing.assert_close(out[:, 0], expected, atol=0, rtol=0)


@pytest.mark.skipif(not t7.is_hopper_or_newer(), reason="Requires Hopper")
def test_tma_im2col_multi_batch_padded():
    run_tma_im2col_multi_batch_padded()
Example 5 Explanation:
Configuration:
Input: [2, 4, 4, 32] - 2 batches of 4x4 images
pixel_box_lower_corner = [-1, -1], pixel_box_upper_corner = [-1, -1]
Lower = pixel_box_lower_corner + offsets = [-1, -1] + [0, 0]
Upper = [H - 1, W - 1] + pixel_box_upper_corner + offsets = [3, 3] + [-1, -1] + [0, 0]
Access boundary per batch: H in [-1, 2], W in [-1, 2] (4x4 shifted region)
coord = [0, 1, 2, 0] - start at batch 0, H=1, W=2
Access boundary visualization (per batch):
Batch 0 - Original 4x4 tensor: Batch 0 - Access boundary H in [-1,2], W in [-1,2]:
[x] = accessed pixel [x] = accessed pixel, [0] = accessed padding
W=0 W=1 W=2 W=3 W=-1 W=0 W=1 W=2
+----+----+----+----+ +----+----+----+----+
H=0 | 1 | 2 | 3 | 4 | H=-1| 0 | 0 | 0 | 0 |
+----+----+----+----+ +----+----+----+----+
H=1 | 5 | 6 | [7]| 8 | H=0 | 0 | 1 | 2 | 3 |
+----+----+----+----+ +----+----+----+----+
H=2 | [9]|[10]|[11]| 12 | H=1 | 0 | 5 | 6 | [7]| <-start
+----+----+----+----+ +----+----+----+----+
H=3 | 13 | 14 | 15 | 16 | H=2 | [0]| [9]|[10]|[11]|
+----+----+----+----+ +----+----+----+----+
Batch 1 - Original 4x4 tensor: Batch 1 - Access boundary H in [-1,2], W in [-1,2]:
[x] = accessed pixel [x] = accessed pixel, [0] = accessed padding
W=0 W=1 W=2 W=3 W=-1 W=0 W=1 W=2
+----+----+----+----+ +----+----+----+----+
H=0 |[17]|[18]|[19]| 20 | H=-1| [0]| [0]| [0]| [0]| <- accessed padding
+----+----+----+----+ +----+----+----+----+
H=1 |[21]|[22]| 23 | 24 | H=0 | [0]|[17]|[18]|[19]|
+----+----+----+----+ +----+----+----+----+
H=2 | 25 | 26 | 27 | 28 | H=1 | [0]|[21]|[22]| 23 |
+----+----+----+----+ +----+----+----+----+
H=3 | 29 | 30 | 31 | 32 | H=2 | 0 | 25 | 26 | 27 |
+----+----+----+----+ +----+----+----+----+
Starting at H=1, W=2 (value 7), loading 16 pixels within the boundary:
pixel 0: (N=0, H=1, W=2) -> value 7 (start position)
pixel 1: (N=0, H=2, W=-1) -> value 0 (padded, W=-1)
pixel 2: (N=0, H=2, W=0) -> value 9
pixel 3: (N=0, H=2, W=1) -> value 10
pixel 4: (N=0, H=2, W=2) -> value 11
pixel 5: (N=1, H=-1, W=-1) -> value 0 (wraps to batch 1, padded)
pixel 6: (N=1, H=-1, W=0) -> value 0 (padded)
pixel 7: (N=1, H=-1, W=1) -> value 0 (padded)
pixel 8: (N=1, H=-1, W=2) -> value 0 (padded)
pixel 9: (N=1, H=0, W=-1) -> value 0 (padded)
pixel 10: (N=1, H=0, W=0) -> value 17
pixel 11: (N=1, H=0, W=1) -> value 18
pixel 12: (N=1, H=0, W=2) -> value 19
pixel 13: (N=1, H=1, W=-1) -> value 0 (padded)
pixel 14: (N=1, H=1, W=0) -> value 21
pixel 15: (N=1, H=1, W=1) -> value 22
Key insight: The access boundary constrains the 4x4 window per batch. When wrapping to the next batch, the boundary resets and padded regions (H=-1 or W=-1) are filled with zeros. This is useful for convolution operations, where it allows seamlessly processing pixel data across image boundaries in a single block.
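The traversal order seen in Examples 1 to 5 can be reproduced with a small software model. The sketch below is an assumption-laden illustration, not a specification of the hardware: emulate_im2col_load is a hypothetical helper, it tracks a single value per pixel (channel 0), and it assumes element_strides of all ones with zero padding, as in the examples.

```python
def emulate_im2col_load(inp, lower_corner, upper_corner, coord, offsets, num_pixels, pad_value=0):
    """Model of the pixel order in the examples above; inp is [N][H][W] of per-pixel values."""
    N, H, W = len(inp), len(inp[0]), len(inp[0][0])
    # Inclusive access boundary, shifted by the offsets.
    h_lo, w_lo = lower_corner[0] + offsets[0], lower_corner[1] + offsets[1]
    h_hi, w_hi = H - 1 + upper_corner[0] + offsets[0], W - 1 + upper_corner[1] + offsets[1]
    # Starting position: spatial coordinates plus offsets.
    n, h, w = coord[0], coord[1] + offsets[0], coord[2] + offsets[1]
    out = []
    for _ in range(num_pixels):
        inside = 0 <= n < N and 0 <= h < H and 0 <= w < W
        out.append(inp[n][h][w] if inside else pad_value)
        w += 1
        if w > w_hi:       # end of row: wrap to the start of the next row
            w, h = w_lo, h + 1
        if h > h_hi:       # end of image window: wrap to the next batch
            h, n = h_lo, n + 1
    return out


# Two 4x4 images holding the values 1..32, as in Examples 4 and 5.
images = [[[n * 16 + h * 4 + w + 1 for w in range(4)] for h in range(4)] for n in range(2)]
print(emulate_im2col_load(images, [0, 0], [0, 0], [0, 1, 3, 0], [0, 0], 16))
# [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
print(emulate_im2col_load(images, [-1, -1], [-1, -1], [0, 1, 2, 0], [0, 0], 16))
# [7, 0, 9, 10, 11, 0, 0, 0, 0, 0, 17, 18, 19, 0, 21, 22]
```

Both printed lists agree with the expected tensors asserted in Examples 4 and 5; passing only the first image with coord=[0, -1, -1, 0] reproduces Examples 2 and 3 in the same way.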
if __name__ == "__main__":
    run_tma_im2col_multi_batch_padded()
    print("\n" + "=" * 50 + "\n")
Example 5: Multi-batch with padded access boundary
Input shape: [2, 4, 4, 32], pixel_box_lower_corner=[-1,-1], pixel_box_upper_corner=[-1,-1]
Access boundary per batch: H in [-1, 2], W in [-1, 2]
Starting at: batch=0, H=1, W=2
Loaded data (pixel_id, value):
pixel 0: 7
pixel 1: 0
pixel 2: 9
pixel 3: 10
pixel 4: 11
pixel 5: 0
pixel 6: 0
pixel 7: 0
pixel 8: 0
pixel 9: 0
pixel 10: 17
pixel 11: 18
pixel 12: 19
pixel 13: 0
pixel 14: 21
pixel 15: 22
==================================================
Implicit GEMM with TMA im2col
Explicit im2col physically constructs the matrix A in memory. This is simple but wastes memory because overlapping patches duplicate input data. For an R x S filter, each input element may be copied up to R * S times!
Implicit GEMM avoids materializing A. During the GEMM K-loop, it computes
input addresses on the fly for each filter position (r, s) and channel
block. On Hopper+ GPUs, TMA provides a dedicated im2col mode that performs
this address generation in hardware. We configure one
TensorDescriptorIm2Col (see the launcher), and each TMA load receives a
filter offset [r, s] that selects which kernel position to gather. TMA
then handles output-position strides, zero-padding for out-of-bounds
accesses, and batch-boundary wrapping automatically, with no register-level
index arithmetic.
Pseudocode:
# GEMM dimensions
M = N * out_h * out_w  # output spatial positions
N_gemm = Co            # output channels
K = R * S * Ci         # reduction dimension
for each M-tile (output positions):
    for each N-tile (output channels):
        acc = zeros(BLOCK_M, BLOCK_N)
        for r in range(R):  # filter height
            for s in range(S):  # filter width
                for ci_block in range(ceil(Ci / BLOCK_K)):
                    # Input tile via TMA im2col:
                    # input[batch, out_y*stride+r-pad, out_x*stride+s-pad, ci_block*BLOCK_K:...]
                    tma.async_load_im2col(...)
                    A_tile = a_smem  # [BLOCK_M, BLOCK_K]
                    # Weight tile via standard TMA:
                    # weight[co_start:..., r, s, ci_block*BLOCK_K:...]
                    tma.async_load(...)
                    B_tile = b_smem  # [BLOCK_N, BLOCK_K]
                    # Matrix multiply-accumulate (details omitted here)
                    acc += A_tile @ B_tile^T
        # Store output tile via TMA
        tma.async_copy_shared_to_global(...)
Key insight: we never materialize the full im2col matrix; address generation happens inside the K-loop via TMA im2col.
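The same loop structure can be modelled on the host without tiling or TMA. In the pure-Python sketch below (conv2d_implicit_gemm is an illustrative name), each (r, s) iteration gathers one input pixel per output position directly from the image, which plays the role of one im2col load, and the full matrix A is never built.

```python
def conv2d_implicit_gemm(inp, weight, stride=1, pad=0):
    """inp: [N][H][W][Ci], weight: [Co][R][S][Ci] -> accumulator of shape [M][Co]."""
    N, H, W, Ci = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    Co, R, S = len(weight), len(weight[0]), len(weight[0][0])
    out_h = (H + 2 * pad - R) // stride + 1
    out_w = (W + 2 * pad - S) // stride + 1
    M = N * out_h * out_w
    acc = [[0] * Co for _ in range(M)]
    for r in range(R):
        for s in range(S):
            # One "A tile" per filter tap: a single pixel for every output position.
            for m in range(M):
                n, rem = divmod(m, out_h * out_w)
                oh, ow = divmod(rem, out_w)
                h, w = oh * stride + r - pad, ow * stride + s - pad
                if not (0 <= h < H and 0 <= w < W):
                    continue  # zero padding contributes nothing
                pixel = inp[n][h][w]
                for co in range(Co):
                    acc[m][co] += sum(pixel[ci] * weight[co][r][s][ci] for ci in range(Ci))
    return acc


image = [[[[1], [2], [3]], [[4], [5], [6]], [[7], [8], [9]]]]  # a..i = 1..9
filt = [[[[1], [2]], [[3], [4]]]]                              # w0..w3 = 1..4
print(conv2d_implicit_gemm(image, filt))  # [[37], [47], [67], [77]]
```

The decomposition of m into (n, oh, ow) with divmod is the scalar analogue of mapping an M-tile offset to a batch index and output coordinates.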
Gluon Kernel
The kernel below implements a single-buffered implicit GEMM convolution.
decompose_m_offset, init_accumulator, and store_output_tile are factored out as
helpers; the K-loop itself is inline so the reader can follow the im2col algorithm top-to-bottom.
MMAImpl (from 07-persistence.py) lets the same kernel run on Hopper
(WGMMA, accumulator in registers) and Blackwell (tcgen05, accumulator in
tensor memory) without any code changes.
Convolution data layout convention for this section:
Input: NHWC [N, H, W, Ci]
Weight: [Co, R, S, Ci] (output-channels first)
Output: NHWC [N, out_h, out_w, Co]
@gluon.jit
def decompose_m_offset(pid_m, BLOCK_M: ttgl.constexpr, out_h, out_w):
    """Map M-tile index to (offs_m, batch, out_y, out_x) for TMA im2col coordinates."""
    offs_m = pid_m * BLOCK_M
    batch_id = offs_m // (out_h * out_w)
    m_residual = offs_m % (out_h * out_w)
    out_y = m_residual // out_w
    out_x = m_residual % out_w
    return offs_m, batch_id, out_y, out_x


@gluon.jit
def init_accumulator(in_desc, weight_desc, MMAImpl, dtype, BLOCK_M: ttgl.constexpr, BLOCK_N: ttgl.constexpr,
                     num_warps: ttgl.constexpr):
    """Allocate shared-memory tiles, MMA accumulator, and TMA barrier."""
    a_smem = ttgl.allocate_shared_memory(dtype, in_desc.block_shape, in_desc.layout)
    b_smem = ttgl.allocate_shared_memory(dtype, weight_desc.block_shape, weight_desc.layout)
    mma = MMAImpl.initialize(dtype, BLOCK_M, BLOCK_N, num_warps)
    tma_bar = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
    mbarrier.init(tma_bar, count=1)
    return a_smem, b_smem, mma, tma_bar


@gluon.jit
def store_output_tile(mma, dtype, out_desc, offs_m, offs_n):
    """Wait for the final MMA, downcast, and write the output tile via TMA."""
    mma = mma.wait_num_outstanding(0)
    acc, mma = mma.take_result()
    c_smem = ttgl.allocate_shared_memory(dtype, out_desc.block_shape, out_desc.layout)
    c_smem.store(acc.to(dtype))
    fence_async_shared()
    tma.async_copy_shared_to_global(out_desc, [offs_m, offs_n], c_smem)
    tma.store_wait(pendings=0)
@gluon.jit
def conv2d_im2col_kernel(
    in_desc,
    weight_desc,
    out_desc,
    R,
    S,
    Ci,
    out_h,
    out_w,
    pad_h,
    pad_w,
    stride_h: ttgl.constexpr,
    stride_w: ttgl.constexpr,
    MMAImpl: ttgl.constexpr,
    BLOCK_M: ttgl.constexpr,
    BLOCK_N: ttgl.constexpr,
    BLOCK_K: ttgl.constexpr,
    num_warps: ttgl.constexpr,
):
    """
    Implicit GEMM convolution kernel using TMA im2col + MMA.
    Works on both Hopper (WGMMA) and Blackwell (tcgen05) via MMAImpl.
    """
    dtype: ttgl.constexpr = in_desc.dtype
    # for each M-tile / N-tile:
    pid_m, pid_n = ttgl.program_id(0), ttgl.program_id(1)
    offs_m, batch_id, out_y, out_x = decompose_m_offset(pid_m, BLOCK_M, out_h, out_w)
    # acc = zeros(BLOCK_M, BLOCK_N)
    a_smem, b_smem, mma, tma_bar = init_accumulator(in_desc, weight_desc, MMAImpl, dtype, BLOCK_M, BLOCK_N, num_warps)
    phase = 0
    # for r in range(R): for s in range(S): for ci_blk in range(ceil(Ci/BLOCK_K))
    # When Ci is not a multiple of BLOCK_K, last channel block is partial; TMA zero-fills OOB.
    ci_num_blocks = ttgl.cdiv(Ci, BLOCK_K)
    total_k_iters = R * S * ci_num_blocks
    for k_iter in range(total_k_iters):
        ci_block, rs_idx = k_iter % ci_num_blocks, k_iter // ci_num_blocks
        r, s = rs_idx // S, rs_idx % S
        # Single-buffered: wait for the previous async MMA to finish reading a_smem/b_smem
        # before TMA overwrites them with the next tiles.
        mma = mma.wait_num_outstanding(0)
        # A = load_input[batch, oh*stride+r-pad, ow*stride+s-pad, ci_blk*BLOCK_K:…]
        # Equivalent TMA-im2col mapping:
        # coord = [batch_id, out_y*stride_h-pad_h, out_x*stride_w-pad_w, ci_block*BLOCK_K]
        # offsets = [r, s]
        # TMA applies offsets to the spatial coords, so start is:
        # [batch_id, out_y*stride_h-pad_h+r, out_x*stride_w-pad_w+s, ci_block*BLOCK_K]
        mbarrier.expect(tma_bar, in_desc.block_type.nbytes + weight_desc.block_type.nbytes)
        tma.async_load_im2col(
            in_desc,
            [batch_id, out_y * stride_h - pad_h, out_x * stride_w - pad_w, ci_block * BLOCK_K],
            [r.to(tl.int16), s.to(tl.int16)],
            tma_bar,
            a_smem,
        )
        # B = load_weight[co_start:…, k_offset:k_offset+BLOCK_K]. Uses Ci (not
        # ci_num_blocks*BLOCK_K) as the per-(r,s) stride so offsets stay aligned with the
        # flat (Co, R*S*Ci) layout. When the last ci block bleeds into the next (r,s)
        # group, those weight elements are multiplied by zero-filled input channels (TMA
        # zero-fills input channels past Ci), so the result is still correct.
        k_offset = r * S * Ci + s * Ci + ci_block * BLOCK_K
        tma.async_load(weight_desc, [pid_n * BLOCK_N, k_offset], tma_bar, b_smem)
        mbarrier.wait(tma_bar, phase=phase)
        # acc += A @ B^T
        mma = mma.issue_async_mma(a_smem, b_smem.permute((1, 0)))
        phase ^= 1
    mbarrier.invalidate(tma_bar)
    # store output[...] = acc
    store_output_tile(mma, dtype, out_desc, offs_m, pid_n * BLOCK_N)
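The flattened K-loop index arithmetic in the kernel is easy to inspect on the host. The sketch below mirrors the k_iter decomposition and the k_offset formula in plain Python (k_schedule is a hypothetical helper, and the sizes R=2, S=2, Ci=96, BLOCK_K=64 are chosen only to show the partial-block case).

```python
def k_schedule(R, S, Ci, BLOCK_K):
    """Yield (k_iter, r, s, ci_block, k_offset) in the order the kernel's K-loop visits them."""
    ci_num_blocks = -(-Ci // BLOCK_K)  # ceiling division
    for k_iter in range(R * S * ci_num_blocks):
        ci_block, rs_idx = k_iter % ci_num_blocks, k_iter // ci_num_blocks
        r, s = rs_idx // S, rs_idx % S
        k_offset = r * S * Ci + s * Ci + ci_block * BLOCK_K
        yield k_iter, r, s, ci_block, k_offset


for step in k_schedule(R=2, S=2, Ci=96, BLOCK_K=64):
    print(step)
# (0, 0, 0, 0, 0)
# (1, 0, 0, 1, 64)
# (2, 0, 1, 0, 96)
# (3, 0, 1, 1, 160)
# (4, 1, 0, 0, 192)
# (5, 1, 0, 1, 256)
# (6, 1, 1, 0, 288)
# (7, 1, 1, 1, 352)
```

With Ci=96 each (r, s) group owns 96 columns of the flat weight matrix, so the tile at k_offset=64 spans columns 64 to 127 and reaches into the (0, 1) group that starts at column 96. This is the "bleed" case that the kernel comments describe as harmless because the matching input channels are zero-filled.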
Host-Side Launcher
The host side sets up TMA descriptors and launches the kernel.
The critical part is configuring the TensorDescriptorIm2Col with the
correct pixel_box, element_strides, and padding to match
the convolution parameters.
t7.select_mma_impl() automatically picks the right MMA backend:
WGMMA on Hopper (SM 9.x) or MMAv5 on Blackwell (SM 10.x).
def conv2d_tma_im2col(input_nhwc, weight, stride=1, padding=0, BLOCK_M=64, BLOCK_N=64, BLOCK_K=64, num_warps=4):
    """
    2D convolution using TMA im2col + MMA (Hopper / Blackwell).

    Args:
        input_nhwc: [N, H, W, Ci] in NHWC format, float16
        weight:     [Co, R, S, Ci] filter tensor, float16
        stride:     convolution stride (default: 1)
        padding:    convolution padding (default: 0)
    """
    MMAImpl = t7.select_mma_impl()
    assert MMAImpl is not None, "conv2d_tma_im2col requires a Hopper or Blackwell GPU"

    N, H, W, Ci = input_nhwc.shape
    Co, R, S, Ci_w = weight.shape
    assert Ci == Ci_w, "Channel mismatch"
    assert stride > 0, f"stride must be positive, got {stride}"
    assert padding >= 0, f"padding must be non-negative, got {padding}"
    # Ci may be any positive value; when it is not divisible by BLOCK_K, the
    # last K tile is partial (TMA zero-fills out-of-bounds channels).

    out_h = (H + 2 * padding - R) // stride + 1
    out_w = (W + 2 * padding - S) // stride + 1
    assert out_h > 0 and out_w > 0, (f"Invalid convolution geometry: out_h={out_h}, out_w={out_w}, "
                                     f"H={H}, W={W}, R={R}, S={S}, stride={stride}, padding={padding}")

    M_GEMM = N * out_h * out_w
    N_GEMM = Co

    output = torch.empty(M_GEMM, Co, device=input_nhwc.device, dtype=input_nhwc.dtype)

    # ── TMA im2col descriptor for input ──────────────────────────────────
    # This configures TMA's hardware im2col mode on the [N, H, W, Ci]
    # input tensor. Key parameters:
    #
    #   block_shape = [BLOCK_M, BLOCK_K]
    #       BLOCK_M output positions x BLOCK_K channels per async copy.
    #
    #   element_strides = [1, stride, stride, 1]
    #       TMA steps by `stride` in H/W between consecutive output
    #       positions, matching the convolution stride.
    #
    #   pixel_box_{lower,upper}_corner
    #       Defines the spatial window that TMA traverses per batch.
    #       Stepping through the window by `stride` must yield exactly
    #       out_h * out_w positions:
    #           window_h = (out_h - 1) * stride + 1
    #           upper_h  = window_h - H - padding
    #       Lower corner is [-padding, -padding].
    #
    #   padding = "zero"
    #       Out-of-bounds reads (from conv padding) return 0 automatically.
    #
    upper_h = (out_h - 1) * stride + 1 - H - padding
    upper_w = (out_w - 1) * stride + 1 - W - padding

    input_block_shape = [BLOCK_M, BLOCK_K]
    input_layout = ttgl.NVMMASharedLayout.get_default_for(input_block_shape, ttgl.float16)
    in_desc = TensorDescriptorIm2Col.from_tensor(
        input_nhwc,
        input_block_shape,
        input_layout,
        padding="zero",
        element_strides=[1, stride, stride, 1],
        pixel_box_lower_corner=[-padding, -padding],
        pixel_box_upper_corner=[upper_h, upper_w],
    )

    # ── TMA descriptor for weight ────────────────────────────────────────
    # Reshape weight [Co, R, S, Ci] -> [Co, R*S*Ci] for standard 2D TMA
    weight_2d = weight.reshape(Co, R * S * Ci)
    weight_block_shape = [BLOCK_N, BLOCK_K]
    weight_layout = ttgl.NVMMASharedLayout.get_default_for(weight_block_shape, ttgl.float16)
    weight_desc = TensorDescriptor.from_tensor(weight_2d, weight_block_shape, weight_layout)

    # ── TMA descriptor for output ────────────────────────────────────────
    # Output is [M_GEMM, Co] (flattened spatial dims).
    out_block_shape = [BLOCK_M, BLOCK_N]
    out_layout = ttgl.NVMMASharedLayout.get_default_for(out_block_shape, ttgl.float16)
    out_desc = TensorDescriptor.from_tensor(output, out_block_shape, out_layout)

    # ── Launch kernel ────────────────────────────────────────────────────
    grid = (triton.cdiv(M_GEMM, BLOCK_M), triton.cdiv(N_GEMM, BLOCK_N))
    conv2d_im2col_kernel[grid](
        in_desc,
        weight_desc,
        out_desc,
        R,
        S,
        Ci,
        out_h,
        out_w,
        padding,
        padding,
        stride,
        stride,
        MMAImpl=MMAImpl,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        num_warps=num_warps,
    )

    return output.reshape(N, out_h, out_w, Co)
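The pixel-box arithmetic in the wrapper can be checked on the host without a GPU. The following is a minimal sketch in plain Python; the helper names `im2col_pixel_box` and `strided_positions` are hypothetical and are not part of Gluon or Triton. It assumes the closed-interval boundary rule stated earlier (upper bound = [H - 1, W - 1] + pixel_box_upper_corner) and counts the base pixels visited when stepping by `stride`.

```python
def im2col_pixel_box(H, W, R, S, stride, padding):
    """Return (out_h, out_w, lower_corner, upper_corner) for one conv geometry.

    Mirrors the formulas used in conv2d_tma_im2col; the function itself is
    an illustrative helper, not a library API.
    """
    out_h = (H + 2 * padding - R) // stride + 1
    out_w = (W + 2 * padding - S) // stride + 1
    upper_h = (out_h - 1) * stride + 1 - H - padding
    upper_w = (out_w - 1) * stride + 1 - W - padding
    return out_h, out_w, [-padding, -padding], [upper_h, upper_w]


def strided_positions(size, lo, hi, stride):
    """Base-pixel coordinates visited along one spatial dimension.

    The window is the closed interval [lo, size - 1 + hi], stepped by `stride`.
    """
    return list(range(lo, size - 1 + hi + 1, stride))


# 16x16 input, 3x3 filter, stride 2, padding 1
out_h, out_w, lower, upper = im2col_pixel_box(16, 16, 3, 3, stride=2, padding=1)
print(out_h, out_w, lower, upper)  # 8 8 [-1, -1] [-2, -2]
print(strided_positions(16, lower[0], upper[0], 2))  # [-1, 1, 3, 5, 7, 9, 11, 13]
```

In this case the window starts one pixel above the image (the padded row) and ends at row 13, so stepping by 2 gives eight base rows, which equals `out_h`. The filter offsets (r, s) then shift this window to reach the remaining taps.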
Testing
We compare our TMA im2col convolution against PyTorch’s conv2d.
@pytest.mark.parametrize("N", [1, 4])
@pytest.mark.parametrize("H,W", [(16, 16)])
@pytest.mark.parametrize("Ci,Co", [(64, 64), (96, 96)])
@pytest.mark.parametrize("R,S", [(3, 3), (1, 1)])
@pytest.mark.parametrize("stride", [1, 2])
@pytest.mark.parametrize("padding", [0, 1])
@pytest.mark.skipif(not t7.is_hopper_or_newer(), reason="Requires Hopper or newer GPU (SM 9.x+)")
def test_conv2d_tma_im2col(N, H, W, Ci, Co, R, S, stride, padding):
    torch.manual_seed(0)
    x_nhwc = torch.randn(N, H, W, Ci, device="cuda", dtype=torch.float16)
    w_co_r_s_ci = torch.randn(Co, R, S, Ci, device="cuda", dtype=torch.float16)

    # Our kernel
    triton_out = conv2d_tma_im2col(x_nhwc, w_co_r_s_ci, stride=stride, padding=padding)

    # PyTorch reference (input NCHW, weight OIHW)
    x_nchw = x_nhwc.permute(0, 3, 1, 2).contiguous()
    w_oihw = w_co_r_s_ci.permute(0, 3, 1, 2).contiguous()
    torch_out = torch.nn.functional.conv2d(x_nchw, w_oihw, stride=stride, padding=padding)
    torch_out_nhwc = torch_out.permute(0, 2, 3, 1)

    torch.testing.assert_close(triton_out, torch_out_nhwc, atol=1e-2, rtol=1e-2)
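To see which tiling cases the parametrized test exercises, the GEMM shape and launch grid can be tabulated on the host. This is a rough sketch in plain Python: `gemm_tiling` and `cdiv` are hypothetical helpers, and the K-step count assumes the kernel walks the reduction dimension as one BLOCK_K channel tile per filter tap (r, s), which is an assumption about the kernel rather than something stated in this section.

```python
def cdiv(a, b):
    """Ceiling division, the same rounding that triton.cdiv performs."""
    return (a + b - 1) // b


def gemm_tiling(N, H, W, Ci, Co, R, S, stride, padding,
                BLOCK_M=64, BLOCK_N=64, BLOCK_K=64):
    """Return (M_GEMM, grid, k_steps) for one convolution configuration."""
    out_h = (H + 2 * padding - R) // stride + 1
    out_w = (W + 2 * padding - S) // stride + 1
    m_gemm = N * out_h * out_w
    grid = (cdiv(m_gemm, BLOCK_M), cdiv(Co, BLOCK_N))
    # Assumption: one K iteration per (r, s, channel tile).
    k_steps = R * S * cdiv(Ci, BLOCK_K)
    return m_gemm, grid, k_steps


# Ci = Co = 96 gives a partial channel tile and a partial BLOCK_N tile.
print(gemm_tiling(4, 16, 16, 96, 96, 3, 3, stride=1, padding=1))  # (1024, (16, 2), 18)
# N = 1, stride 2, no padding gives a 7x7 output: 49 rows, fewer than BLOCK_M.
print(gemm_tiling(1, 16, 16, 64, 64, 3, 3, stride=2, padding=0))  # (49, (1, 1), 9)
```

The (96, 96) channel case and the 49-row case are the ones that cover partial tiles in K, N, and M, which is where out-of-bounds handling matters most.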
Summary
In this tutorial we learned:
TMA im2col fundamentals: How TensorDescriptorIm2Col parameters (pixel_box_lower_corner, pixel_box_upper_corner, element_strides, padding) control access boundaries on NHWC tensors, including padded regions, offset shifts, and multi-batch wrapping.
Convolution as GEMM: The im2col transformation converts convolution into a matrix multiplication by rearranging input patches into rows of a matrix.
Implicit GEMM: Instead of materializing the im2col matrix, we compute input addresses on-the-fly during the GEMM K-loop, saving memory.
TMA im2col: TMA hardware natively supports im2col address computation, eliminating register pressure from index arithmetic. We configure a TensorDescriptorIm2Col with convolution parameters (strides, padding, pixel box), and TMA handles the rest. This feature is available on both Hopper and Blackwell GPUs.
Unified kernel via MMA abstraction: By using the WGMMA / MMAv5 abstraction from 07-persistence.py, a single kernel implementation works on both Hopper and Blackwell. The abstraction handles:
- Accumulator allocation (registers on Hopper, tensor memory on Blackwell)
- MMA barrier management
- Async MMA issue and wait
The kernel itself only deals with TMA loads and the unified MMA API: initialize, issue_async_mma, wait_num_outstanding, take_result.
For a production-quality implementation with warp specialization and
pipelining, see examples/gluon/02-convolution.py.
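The "Convolution as GEMM" and "Implicit GEMM" points can be illustrated on the CPU with a tiny integer example. The sketch below is plain Python on nested lists, not the Gluon kernel; all function names are hypothetical. It compares a direct sliding-window convolution against a GEMM whose A-matrix entries are computed on the fly from (m, k), using the same K ordering as weight.reshape(Co, R * S * Ci).

```python
def pixel(x, n, h, w, c):
    """Read x[n][h][w][c], returning 0 outside the image (zero padding)."""
    H, W = len(x[0]), len(x[0][0])
    return x[n][h][w][c] if 0 <= h < H and 0 <= w < W else 0


def out_size(size, k, stride, padding):
    return (size + 2 * padding - k) // stride + 1


def conv2d_direct(x, wgt, stride, padding):
    """Sliding-window convolution. x: [N][H][W][Ci], wgt: [Co][R][S][Ci]."""
    N, H, W, Ci = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    Co, R, S = len(wgt), len(wgt[0]), len(wgt[0][0])
    out_h, out_w = out_size(H, R, stride, padding), out_size(W, S, stride, padding)
    out = []
    for n in range(N):
        for oh in range(out_h):
            for ow in range(out_w):
                out.append([
                    sum(pixel(x, n, oh * stride - padding + r, ow * stride - padding + s, c)
                        * wgt[co][r][s][c]
                        for r in range(R) for s in range(S) for c in range(Ci))
                    for co in range(Co)
                ])
    return out  # rows are flattened (n, oh, ow); columns are Co


def conv2d_implicit_gemm(x, wgt, stride, padding):
    """Same result as a GEMM: out[m][co] = sum_k A[m][k] * B[co][k]."""
    N, H, W, Ci = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    Co, R, S = len(wgt), len(wgt[0]), len(wgt[0][0])
    out_h, out_w = out_size(H, R, stride, padding), out_size(W, S, stride, padding)
    M, K = N * out_h * out_w, R * S * Ci
    # B is the weight flattened to [Co][K], with k = (r * S + s) * Ci + c.
    b = [[wgt[co][r][s][c] for r in range(R) for s in range(S) for c in range(Ci)]
         for co in range(Co)]
    out = [[0] * Co for _ in range(M)]
    for m in range(M):
        n, rem = divmod(m, out_h * out_w)
        oh, ow = divmod(rem, out_w)
        for k in range(K):
            rs, c = divmod(k, Ci)
            r, s = divmod(rs, S)
            # A[m][k] is never stored; it is looked up from the input directly.
            a_mk = pixel(x, n, oh * stride - padding + r, ow * stride - padding + s, c)
            for co in range(Co):
                out[m][co] += a_mk * b[co][k]
    return out


# Small deterministic tensors: N=2, H=4, W=5, Ci=2 and Co=3, R=S=3.
x = [[[[(n * 7 + h * 5 + w * 3 + c) % 11 - 5 for c in range(2)]
       for w in range(5)] for h in range(4)] for n in range(2)]
wgt = [[[[(co * 3 + r * 2 + s + c) % 5 - 2 for c in range(2)]
         for s in range(3)] for r in range(3)] for co in range(3)]
assert conv2d_direct(x, wgt, 1, 1) == conv2d_implicit_gemm(x, wgt, 1, 1)
```

In the TMA version, the index decomposition done here by `divmod` and the zero fill done by `pixel` are what the im2col descriptor takes over, so the kernel only issues tile loads.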
if __name__ == "__main__":
    # Conv2d demo
    torch.manual_seed(0)
    N, H, W, Ci, Co, R, S = 4, 16, 16, 64, 64, 3, 3
    stride, padding = 1, 1

    x_nhwc = torch.randn(N, H, W, Ci, device="cuda", dtype=torch.float16)
    w_co_r_s_ci = torch.randn(Co, R, S, Ci, device="cuda", dtype=torch.float16)

    triton_out = conv2d_tma_im2col(x_nhwc, w_co_r_s_ci, stride=stride, padding=padding)

    # Compare with PyTorch (input NCHW, weight OIHW)
    x_nchw = x_nhwc.permute(0, 3, 1, 2).contiguous()
    w_oihw = w_co_r_s_ci.permute(0, 3, 1, 2).contiguous()
    torch_out = torch.nn.functional.conv2d(x_nchw, w_oihw, stride=stride, padding=padding)
    torch_out_nhwc = torch_out.permute(0, 2, 3, 1)

    max_err = (triton_out - torch_out_nhwc).abs().max().item()
    print(f"Conv2d: N={N}, H={H}, W={W}, Ci={Ci}, Co={Co}, R={R}, S={S}, "
          f"stride={stride}, pad={padding}")
    print(f"Output shape: {list(triton_out.shape)}")
    print(f"Max absolute error vs PyTorch: {max_err:.6f}")
    print("PASSED!" if max_err < 0.02 else "FAILED!")
Conv2d: N=4, H=16, W=16, Ci=64, Co=64, R=3, S=3, stride=1, pad=1
Output shape: [4, 16, 16, 64]
Max absolute error vs PyTorch: 0.000000
PASSED!