TritonNvidiaGPUOps

ttng.async_tma_copy_global_to_local (triton::nvidia_gpu::AsyncTMACopyGlobalToLocalOp)

Copy data based on descriptor from global memory to local memory asynchronously

Syntax:

operation ::= `ttng.async_tma_copy_global_to_local` $desc_ptr `[` $coord `]` $result `,` $barrier `,` $pred
              oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict)
              attr-dict `:` qualified(type($desc_ptr)) `,` qualified(type($barrier)) `->` qualified(type($result))

This operation copies data from global memory to local memory asynchronously. It is analogous to tt.load, except that the data is copied to the local memory pointed to by the memory descriptor instead of to a distributed tensor. The data copied depends on the global memory descriptor pointed to by desc_ptr.
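
A minimal sketch of the textual form (the #shared, #shared1, and #smem layout attributes, as well as all shapes and element types, are illustrative assumptions, not prescribed by the op):

ttng.async_tma_copy_global_to_local %desc[%x, %y] %buf, %bar, %pred : !tt.ptr<i8>, !ttg.memdesc<1xi64, #shared1, #smem, mutable> -> !ttg.memdesc<128x64xf16, #shared, #smem, mutable>

Here %bar is an mbarrier allocation (see ttng.init_barrier) that is signaled as the copy completes.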

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| cache | ::mlir::triton::CacheModifierAttr | allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6, 7. Enum cases: none (NONE), ca (CA), cg (CG), wb (WB), cs (CS), wt (WT), cv (CV) |
| evict | ::mlir::triton::EvictionPolicyAttr | allowed 32-bit signless integer cases: 1, 2, 3. Enum cases: evict_normal (NORMAL), evict_first (EVICT_FIRST), evict_last (EVICT_LAST) |
| isVolatile | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| ------- | ----------- |
| desc_ptr | Pointer type (::mlir::triton::PointerType) in Triton IR type system |
| coord | variadic of 32-bit signless integer |
| barrier | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| result | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| pred | 1-bit signless integer |

ttng.async_tma_copy_local_to_global (triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp)

Copy data based on descriptor from local memory to global memory asynchronously

Syntax:

operation ::= `ttng.async_tma_copy_local_to_global` $desc_ptr `[` $coord `]` $src
              attr-dict `:` qualified(type($desc_ptr)) `,` qualified(type($src))

This operation copies data from local memory to global memory asynchronously. It is analogous to tt.store, except that the data is copied from the local memory pointed to by the memory descriptor instead of from a distributed tensor. The data copied depends on the global memory descriptor pointed to by desc_ptr.
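
A sketch under the same illustrative assumptions as above:

ttng.async_tma_copy_local_to_global %desc[%x, %y] %src : !tt.ptr<i8>, !ttg.memdesc<128x64xf16, #shared, #smem>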

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface

Operands:

| Operand | Description |
| ------- | ----------- |
| desc_ptr | Pointer type (::mlir::triton::PointerType) in Triton IR type system |
| coord | variadic of 32-bit signless integer |
| src | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |

ttng.async_tma_gather (triton::nvidia_gpu::AsyncTMAGatherOp)

Gather data based on descriptor from global memory to local memory asynchronously

Syntax:

operation ::= `ttng.async_tma_gather` $desc_ptr `[` $x_offsets `,` $y_offset `]` $result `,` $barrier `,` $pred
              attr-dict `:` type(operands)

This operation gathers multiple rows of data from a matrix in global memory to local memory asynchronously. It is similar to async_tma_copy_global_to_local, except that each row is indexed independently.
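
A sketch (the #blocked tensor layout and all shapes are illustrative assumptions); each of the 32 gathered rows is taken from the row index in %rows, starting at column %col:

ttng.async_tma_gather %desc[%rows, %col] %buf, %bar, %pred : !tt.ptr<i8>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<1xi64, #shared1, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared, #smem, mutable>, i1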

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface

Operands:

| Operand | Description |
| ------- | ----------- |
| desc_ptr | Pointer type (::mlir::triton::PointerType) in Triton IR type system |
| x_offsets | ranked tensor of 32-bit signless integer values |
| y_offset | 32-bit signless integer |
| barrier | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| result | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| pred | 1-bit signless integer |

ttng.async_tma_scatter (triton::nvidia_gpu::AsyncTMAScatterOp)

Scatter data from local memory into global memory based on a descriptor asynchronously

Syntax:

operation ::= `ttng.async_tma_scatter` $desc_ptr `[` $x_offsets `,` $y_offset `]` $src
              attr-dict `:` type(operands)

The ttng.async_tma_scatter operation scatters multiple separately indexed rows of data from local memory into global memory asynchronously. The operation scatters a 2D tensor in shared memory, laid out by core tensor tiles in the nvmma_shared layout, into separately indexed rows in global memory at a given y offset.
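
A sketch mirroring the gather form above (layouts and shapes are illustrative assumptions):

ttng.async_tma_scatter %desc[%rows, %col] %src : !tt.ptr<i8>, tensor<32xi32, #blocked>, i32, !ttg.memdesc<32x128xbf16, #shared, #smem>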

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface

Operands:

| Operand | Description |
| ------- | ----------- |
| desc_ptr | Pointer type (::mlir::triton::PointerType) in Triton IR type system |
| x_offsets | ranked tensor of 32-bit signless integer values |
| y_offset | 32-bit signless integer |
| src | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |

ttng.barrier_expect (triton::nvidia_gpu::BarrierExpectOp)

Signal a barrier with the expected number of bytes to be copied.

Syntax:

operation ::= `ttng.barrier_expect` $alloc `,` $size attr-dict `,` $pred `:` qualified(type($alloc))

This signals the barrier that size bytes are expected to be copied. The associated barrier wait will block until the expected number of bytes are copied.
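
For example, before waiting on a TMA copy of a 128x64 f16 tile (128 * 64 * 2 = 16384 bytes), one might signal (the barrier type is illustrative):

ttng.barrier_expect %bar, 16384, %pred : !ttg.memdesc<1xi64, #shared1, #smem, mutable>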

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| size | ::mlir::IntegerAttr | 32-bit signless integer attribute |

Operands:

| Operand | Description |
| ------- | ----------- |
| alloc | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| pred | 1-bit signless integer |

ttng.cluster_arrive (triton::nvidia_gpu::ClusterArriveOp)

Syntax:

operation ::= `ttng.cluster_arrive` attr-dict

Traits: VerifyTensorLayoutsTrait

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| relaxed | ::mlir::IntegerAttr | 1-bit signless integer attribute |

ttng.cluster_wait (triton::nvidia_gpu::ClusterWaitOp)

Syntax:

operation ::= `ttng.cluster_wait` attr-dict

Traits: VerifyTensorLayoutsTrait

ttng.fence_async_shared (triton::nvidia_gpu::FenceAsyncSharedOp)

Fence proxy async

Syntax:

operation ::= `ttng.fence_async_shared` attr-dict

Traits: VerifyTensorLayoutsTrait

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| bCluster | ::mlir::BoolAttr | bool attribute |

ttng.init_barrier (triton::nvidia_gpu::InitBarrierOp)

Initialize a barrier in the given shared memory allocation.

Syntax:

operation ::= `ttng.init_barrier` $alloc `,` $count attr-dict `:` qualified(type($alloc))

Initializes a shared memory allocation with mbarrier information. alloc is a descriptor to the shared memory allocation. count is the number of arrives expected by the barrier.

This lowers to PTX mbarrier.init.shared::cta.b64.
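
For example, initializing a barrier that expects a single arrive (the 1xi64 memdesc is the conventional mbarrier allocation; the layout attributes are illustrative assumptions):

ttng.init_barrier %bar, 1 : !ttg.memdesc<1xi64, #shared1, #smem, mutable>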

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| count | ::mlir::IntegerAttr | 32-bit signless integer attribute |

Operands:

| Operand | Description |
| ------- | ----------- |
| alloc | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |

ttng.inval_barrier (triton::nvidia_gpu::InvalBarrierOp)

Invalidate a barrier allocation.

Syntax:

operation ::= `ttng.inval_barrier` $alloc attr-dict `:` qualified(type($alloc))

Invalidates a barrier allocation so that it can be re-used. According to the PTX spec, this has to be done before any reuse of the memory used by the mbarrier.

https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval
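
A sketch, paired with the init_barrier example above:

ttng.inval_barrier %bar : !ttg.memdesc<1xi64, #shared1, #smem, mutable>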

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface

Operands:

| Operand | Description |
| ------- | ----------- |
| alloc | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |

ttng.tc_gen5_mma (triton::nvidia_gpu::TCGen5MMAOp)

Block-level op mapping to TensorCore gen5 MMA

Syntax:

operation ::= `ttng.tc_gen5_mma` $a`,` $b`,` $d`,` $useD`,` $pred (`,` $barrier^)? attr-dict `:` functional-type(operands, results)

$d += matrix_multiply($a, $b). If no barrier is given, the op is assumed to be synchronous; otherwise, the op will trigger a commit/arrive on the given barrier, and the result will be safe to read after a barrier wait. If $two_ctas is set, the op will execute a matmul across two contiguous CTAs: it will read the data distributed across the two CTAs, and synchronize both CTAs if the op is synchronous.
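
An asynchronous sketch with a barrier (the #tmem accumulator encoding and all shapes and layouts are illustrative assumptions):

ttng.tc_gen5_mma %a, %b, %acc, %useD, %pred, %bar : (!ttg.memdesc<128x64xf16, #shared, #smem>, !ttg.memdesc<64x128xf16, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, i1, i1, !ttg.memdesc<1xi64, #shared1, #smem, mutable>) -> ()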

Traits: VerifyTensorLayoutsTrait

Interfaces: DotOpInterface, MMAv5OpInterface, MemoryEffectOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| two_ctas | ::mlir::UnitAttr | unit attribute |

Operands:

| Operand | Description |
| ------- | ----------- |
| a | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| b | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| d | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| useD | 1-bit signless integer |
| pred | 1-bit signless integer |
| barrier | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |

ttng.tc_gen5_mma_scaled (triton::nvidia_gpu::TCGen5MMAScaledOp)

Block-level op mapping to TensorCore gen5 MMA with scaled operands

Syntax:

operation ::= `ttng.tc_gen5_mma_scaled` $a `,` $b `,` $d `,` $a_scale `,` $b_scale `,` $useD`,` $pred `lhs` `=` $a_type `rhs` `=` $b_type (`,` $barrier^)? attr-dict `:` functional-type(operands, results)

$d += matrix_multiply(scale($a, $a_scale), scale($b, $b_scale)). If no barrier is given, the op is assumed to be synchronous; otherwise, the op will trigger a commit/arrive on the given barrier, and the result will be safe to read after a barrier wait.
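
A hedged sketch for e4m3 inputs with MXFP scales (with scale_vec_size = 32, a 64-wide K block carries 2 scale columns; all element types, shapes, and encodings are illustrative assumptions):

ttng.tc_gen5_mma_scaled %a, %b, %acc, %sa, %sb, %useD, %pred lhs = e4m3 rhs = e4m3 : (!ttg.memdesc<128x64xi8, #shared, #smem>, !ttg.memdesc<64x128xi8, #shared, #smem>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>, !ttg.memdesc<128x2xi8, #tmem_scales, #ttng.tensor_memory>, i1, i1) -> ()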

Traits: VerifyTensorLayoutsTrait

Interfaces: DotOpInterface, MMAv5OpInterface, MemoryEffectOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| a_type | ::mlir::triton::ScaleDotElemTypeAttr | allowed 32-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6. Enum cases: e4m3 (E4M3), e5m2 (E5M2), e2m3 (E2M3), e3m2 (E3M2), e2m1 (E2M1), bf16 (BF16), fp16 (FP16) |
| b_type | ::mlir::triton::ScaleDotElemTypeAttr | allowed 32-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6. Enum cases: e4m3 (E4M3), e5m2 (E5M2), e2m3 (E2M3), e3m2 (E3M2), e2m1 (E2M1), bf16 (BF16), fp16 (FP16) |

Operands:

| Operand | Description |
| ------- | ----------- |
| a | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| b | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| d | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| a_scale | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| b_scale | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| useD | 1-bit signless integer |
| pred | 1-bit signless integer |
| barrier | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |

ttng.async_tma_store_wait (triton::nvidia_gpu::TMAStoreWaitOp)

Wait until all the inputs are read.

Syntax:

operation ::= `ttng.async_tma_store_wait` attr-dict

Waits until the associated asynchronous store operations have finished reading their inputs from shared memory. This is needed before the shared memory can be written to again.
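
For example, waiting until no associated stores are still reading from shared memory:

ttng.async_tma_store_wait {pendings = 0 : i32}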

Traits: VerifyTensorLayoutsTrait

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| pendings | ::mlir::IntegerAttr | 32-bit signless integer attribute |

ttng.tmem_alloc (triton::nvidia_gpu::TMEMAllocOp)

Allocate tensor memory

Syntax:

operation ::= `ttng.tmem_alloc` $src attr-dict `:` functional-type(operands, results)

This operation allocates a buffer in tensor memory and returns a descriptor containing the address and a view of the buffer. This is similar to ttg.local_alloc, except that the buffer is allocated in tensor memory.

Explicitly deallocating a buffer is optional; see local_dealloc.
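
A sketch (the #blocked and #tmem encodings are illustrative assumptions):

%acc = ttng.tmem_alloc %init : (tensor<128x128xf32, #blocked>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>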

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}

Operands:

| Operand | Description |
| ------- | ----------- |
| src | ranked tensor of floating-point or integer or ptr values |

Results:

| Result | Description |
| ------ | ----------- |
| result | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |

ttng.tmem_copy (triton::nvidia_gpu::TMEMCopyOp)

Initiate an asynchronous copy operation from shared memory to the Tensor Memory.

Syntax:

operation ::= `ttng.tmem_copy` $src `,` $dst `,` $barrier attr-dict `:` functional-type(operands, results)

2D blocks stored contiguously in SMEM are copied into TMEM as specified by the destination address. The completion of the copy can be observed by waiting on the optional barrier. If this op is used together with an MMA op, one barrier can be used to wait for both copy and MMA. We do not need to wait for the completion of the copy before MMA, since tcgen05.cp followed by tcgen05.mma is guaranteed to execute in that order.

This op lowers to the PTX instruction tcgen05.cp. Right now, we only support 1CTA and the warpx4.32x128b variant of the instruction. Each 32x128b block in SMEM is duplicated over 4 warps and stored into 128 rows and 4 columns of TMEM. The primary use case of this op is to copy blocked scales from SMEM to TMEM.

The shape of the input SMEM can be flexibly chosen depending on the use case. In the simplest case (e.g. a unit test), the source SMEM can be of shape (32 x num_blocks, 16), and the destination TMEM should be of shape (128, 16 x num_blocks), for copying 8-bit values. For scaled GEMM, rep_m x rep_k copies of a 32x128b block need to be stored in SMEM, where rep_m = BLOCK_M / 128, rep_k = BLOCK_K / scale_vec_size / 4, and scale_vec_size = 32 for MXFP. Conceptually, the SMEM is organized in a high-dimensional layout, (rep_m, rep_k, 32, 4, 4B). Some of the axes can be flattened into one to reduce the rank of the load. For example, the following patterns are supported:

  • (rep_m, rep_k * 32 x 4 x 4B), 2D scale load with cp.async

  • (rep_m, rep_k, 32, 16B), 4D scale load with TMA

  • (rep_m, rep_k, 32, 4, 4B), 5D scale load with cp.async

Since rep_m blocks are not contiguous in SMEM, the rep_m axis cannot be flattened into the inner ones.

In Triton, the TMEM memdesc for blocked scales must be of the following form:

  • Its shape must be (BLOCK_MN, BLOCK_K / scale_vec_size), representing the logical shape of blocked scales.

  • It must be attached with tensor_memory_scales_encoding to indicate the chunk-based layout and its duplication over 4 warps.

In contrast, the src SMEM must be in the explicit chunk-based layout as described above. So the IR might look like this:

ttng.tmem_copy %1, %0 : (!ttg.memdesc<1x1x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>) -> ()

We interpret the semantics of this copy operation as follows. The chunk-based layout in SMEM implies that the logical shape (BLOCK_MN, BLOCK_K / scale_vec_size) in TMEM is the result of certain reshape and transpose operations. In practice, to take advantage of the native scale layout and the TMEM copy op, users need to do scales5D.trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // scale_vec_size) before feeding scales into dot_scaled. When we use tmem_copy in the IR, such reshape and transpose operations are removed, but the change in the logical shape that they caused on registers is now understood to be incorporated into tmem_copy itself. Ideally, we would lift the reshape/transpose done on registers onto the SMEM memdesc, making tmem_copy a straightforward 2D copy operation: (BLOCK_MN, BLOCK_K / scale_vec_size) -> (BLOCK_MN, BLOCK_K / scale_vec_size). In the absence of such operations on memdesc, we resort to implicitly encoding the reshape/transpose semantics in tmem_copy.

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}

Operands:

| Operand | Description |
| ------- | ----------- |
| src | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| dst | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| barrier | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |

ttng.tmem_load (triton::nvidia_gpu::TMEMLoadOp)

Load a buffer from tensor memory into a distributed tensor

Syntax:

operation ::= `ttng.tmem_load` $src attr-dict `:` qualified(type($src)) `->` type($result)

This is similar to ttg.local_load, except that the result layout is restricted to only a few possibilities. Therefore, unlike local_load, this op cannot be combined with an arbitrary layout conversion.
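
A sketch, e.g. reading back an MMA accumulator (encodings are illustrative assumptions):

%vals = ttng.tmem_load %acc : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked>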

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::SideEffects::DefaultResource}

Operands:

| Operand | Description |
| ------- | ----------- |
| src | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |

Results:

| Result | Description |
| ------ | ----------- |
| result | ranked tensor of floating-point or integer or ptr values |

ttng.tmem_store (triton::nvidia_gpu::TMEMStoreOp)

Store a distributed tensor into a buffer in tensor memory

Syntax:

operation ::= `ttng.tmem_store` $src `,` $dst `,` $pred attr-dict `:` type($src) `->` qualified(type($dst))

This is similar to ttg.local_store, except that the source layout is restricted to only a few possibilities.
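
A sketch, e.g. a predicated store into an accumulator buffer (encodings are illustrative assumptions):

ttng.tmem_store %vals, %acc, %pred : tensor<128x128xf32, #blocked> -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>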

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::SideEffects::DefaultResource}

Operands:

| Operand | Description |
| ------- | ----------- |
| dst | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| src | ranked tensor of floating-point or integer or ptr values |
| pred | 1-bit signless integer |

ttng.tensor_desc_to_tma_ptr (triton::nvidia_gpu::TensorDescToTMAPtrOp)

Convert a tensor descriptor to a pointer to the TMA descriptor

Syntax:

operation ::= `ttng.tensor_desc_to_tma_ptr` $desc attr-dict `:` qualified(type($desc)) `to` qualified(type($ptr))
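
A sketch (the descriptor's tensor shape and the i8 pointee type are illustrative assumptions):

%ptr = ttng.tensor_desc_to_tma_ptr %desc : !tt.tensordesc<tensor<128x64xf16>> to !tt.ptr<i8>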

Traits: AlwaysSpeculatableImplTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| ------- | ----------- |
| desc | Tensor descriptor type (::mlir::triton::TensorDescType) in Triton IR type system |

Results:

| Result | Description |
| ------ | ----------- |
| ptr | ptr |

ttng.wait_barrier (triton::nvidia_gpu::WaitBarrierOp)

Wait until the mbarrier phase completes.

Syntax:

operation ::= `ttng.wait_barrier` $alloc `,` $phase attr-dict (`,` $pred^)? `:` qualified(type($alloc))

Blocks the program progress until the mbarrier object in alloc completes its current phase.

This lowers to a wait loop using the PTX instruction mbarrier.try_wait.parity.shared.b64.

The barrier behavior is described here: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-asynchronous-copy-completion-mechanisms
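
A sketch (the barrier type is illustrative), where %phase selects the phase parity to wait on:

ttng.wait_barrier %bar, %phase : !ttg.memdesc<1xi64, #shared1, #smem, mutable>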

Traits: VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface

Operands:

| Operand | Description |
| ------- | ----------- |
| alloc | memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system |
| phase | 32-bit signless integer |
| pred | 1-bit signless integer |

ttng.warp_group_dot (triton::nvidia_gpu::WarpGroupDotOp)

Warp group dot

Syntax:

operation ::= `ttng.warp_group_dot` $a`,` $b`,` $c (`,` $useC^)? attr-dict `:` type($a) `*` type($b) `->` type($d)

$d = matrix_multiply($a, $b) + $c. For docs on InputPrecisionAttr, see TT_DotOp.
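
A sketch with shared-memory operands (the #shared, #smem, and #mma encodings and all shapes are illustrative assumptions):

%d = ttng.warp_group_dot %a, %b, %c : !ttg.memdesc<128x64xf16, #shared, #smem> * !ttg.memdesc<64x128xf16, #shared, #smem> -> tensor<128x128xf32, #mma>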

Traits: VerifyTensorLayoutsTrait

Interfaces: DotOpInterface, InferTypeOpInterface, MemoryEffectOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| inputPrecision | ::mlir::triton::InputPrecisionAttr | allowed 32-bit signless integer cases: 0, 1, 2. Enum cases: tf32 (TF32), tf32x3 (TF32x3), ieee (IEEE) |
| maxNumImpreciseAcc | ::mlir::IntegerAttr | 32-bit signless integer attribute |
| isAsync | ::mlir::BoolAttr | bool attribute |

Operands:

| Operand | Description |
| ------- | ----------- |
| a | TensorOrMemDesc instance |
| b | TensorOrMemDesc instance |
| c | ranked tensor of floating-point or integer values |
| useC | 1-bit signless integer |

Results:

| Result | Description |
| ------ | ----------- |
| d | ranked tensor of floating-point or integer values |

ttng.warp_group_dot_wait (triton::nvidia_gpu::WarpGroupDotWaitOp)

Warp group dot wait

Syntax:

operation ::= `ttng.warp_group_dot_wait` $inputs attr-dict `:` type($inputs)

Waits until there are $pendings or fewer outstanding async dot operations.

$inputs must be the tensors corresponding to the async dot ops that we’re waiting on. For example, if there are N pending async dot ops and we call warp_group_dot_wait 1, then $inputs must be the result of the first dot op.
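
For example, waiting until all outstanding async dots producing %d have completed (the #mma encoding is an illustrative assumption):

%0 = ttng.warp_group_dot_wait %d {pendings = 0 : i32} : tensor<128x128xf32, #mma>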

Traits: VerifyTensorLayoutsTrait

Interfaces: InferTypeOpInterface

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| pendings | ::mlir::IntegerAttr | 32-bit signless integer attribute |

Operands:

| Operand | Description |
| ------- | ----------- |
| inputs | variadic of TensorOrMemDesc instance |

Results:

| Result | Description |
| ------ | ----------- |
| outputs | variadic of TensorOrMemDesc instance |