TritonAMDGPUOps

amdg.arrive_barrier (triton::amdgpu::ArriveBarrierOp)

Perform the arrive operation on an mbarrier

Syntax:

operation ::= `amdg.arrive_barrier` $alloc `,` $count attr-dict `:` qualified(type($alloc)) `->` type($result)

Performs the “arrive” operation on an mbarrier object in shared memory. The operation requires a count attribute of at least 1, and decreases the pending arrival count of the mbarrier by the specified count. If the pending count reaches zero, the phase changes (is decremented in a wraparound manner) and the pending count is reloaded with the init count value. Returns the phase of the mbarrier object prior to the “arrive” operation.

Example:

%phase = amdg.arrive_barrier %barrier, 2 : !ttg.memdesc<1xi64, #shared, #smem, mutable> -> i32

Interfaces: InferTypeOpInterface

Attributes:

Attribute  MLIR Type            Description
count      ::mlir::IntegerAttr  32-bit signless integer attribute

Operands:

Operand  Description
alloc    memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system

Results:

Result   Description
result   32-bit signless integer

amdg.async_copy_mbarrier_arrive (triton::amdgpu::AsyncCopyMbarrierArriveOp)

Arrive on mbarrier once all previously issued copies are completed

Syntax:

operation ::= `amdg.async_copy_mbarrier_arrive` $barrier attr-dict `:` qualified(type($barrier))

Performs the “async arrive” operation by decrementing the barrier’s pending count by 1 once all previously issued async loads to LDS (notably, not TDM loads) have completed. The instruction itself is asynchronous and returns immediately; the decrement value is fixed at 1. If the pending count reaches zero, the phase changes (is decremented in a wraparound manner) and the pending count is reloaded with the init count value.
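
Example (an illustrative sketch, assuming an mbarrier allocation like the one used by amdg.arrive_barrier above):

amdg.async_copy_mbarrier_arrive %barrier : !ttg.memdesc<1xi64, #shared, #smem, mutable>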

Operands:

Operand  Description
barrier  memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system

amdg.async_tdm_copy_global_to_local (triton::amdgpu::AsyncTDMCopyGlobalToLocalOp)

Copy data based on descriptor from global memory to local memory asynchronously

Syntax:

operation ::= `amdg.async_tdm_copy_global_to_local` $desc `[` $indices `]` `into` $result `,` $pred (`,` `barrier` `=` $barrier^)?
              attr-dict `:` qualified(type($desc)) (`,` qualified(type($barrier))^)? `->` qualified(type($result))

This operation copies data from global memory to local memory asynchronously. It is analogous to tt.load, except that the data is copied to the local memory pointed to by result instead of to a distributed tensor. The data copied depends on the global memory described by desc. Setting pred to false disables the copy. This operation does not support shared memory swizzling. The operation can also take an optional 64-bit LDS barrier address, in which case it sends an “LDS atomic arrive” to signal its completion.
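
Example (an illustrative sketch assembled from the syntax above; the descriptor, index, and buffer names are hypothetical):

%token = amdg.async_tdm_copy_global_to_local %desc[%x, %y] into %buf, %pred
    : !tt.tensordesc<tensor<64x64xf16>> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>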

Traits: AttrSizedOperandSegments

Interfaces: InferTypeOpInterface

Operands:

Operand  Description
desc     tensor descriptor type (::mlir::triton::TensorDescType) in Triton IR type system
indices  variadic of 32-bit signless integer
result   memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system
pred     1-bit signless integer
barrier  memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system

Results:

Result  Description
token   async token type

amdg.async_tdm_copy_local_to_global (triton::amdgpu::AsyncTDMCopyLocalToGlobalOp)

Copy data based on descriptor from local memory to global memory asynchronously

Syntax:

operation ::= `amdg.async_tdm_copy_local_to_global` $desc `[` $indices `]` `from` $src
              attr-dict `:` qualified(type($src)) `->` qualified(type($desc))

This operation copies data from local memory to global memory asynchronously. It is analogous to tt.store, except that the data is copied from the local memory pointed to by src instead of from a distributed tensor. The copy destination depends on the global memory described by desc. This operation does not support shared memory padding or swizzling.
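
Example (an illustrative sketch assembled from the syntax above, mirroring the global-to-local example):

amdg.async_tdm_copy_local_to_global %desc[%x, %y] from %buf
    : !ttg.memdesc<64x64xf16, #shared, #smem, mutable> -> !tt.tensordesc<tensor<64x64xf16>>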

Operands:

Operand  Description
desc     tensor descriptor type (::mlir::triton::TensorDescType) in Triton IR type system
indices  variadic of 32-bit signless integer
src      memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system

amdg.async_tdm_wait (triton::amdgpu::AsyncTDMWait)

Wait until there are less than or equal to the given number of outstanding TDM operations

Syntax:

operation ::= `amdg.async_tdm_wait` $asyncToken attr-dict

This operation waits until there are less than or equal to the given number of outstanding TDM operations, including both loads and stores. This is necessary to ensure that data is available in the LDS before it is used.
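
Example (an illustrative sketch; num = 0 waits for all outstanding TDM operations to complete):

%0 = amdg.async_tdm_wait %token {num = 0 : i32}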

Traits: MemWaitOpTrait

Interfaces: InferTypeOpInterface

Attributes:

Attribute  MLIR Type            Description
num        ::mlir::IntegerAttr  32-bit signless integer attribute

Operands:

Operand     Description
asyncToken  variadic of async token type

Results:

Result    Description
retToken  async token type

amdg.async_wait (triton::amdgpu::AsyncWaitOp)

Wait until there are less than or equal to the given number of outstanding async intrinsics

Syntax:

operation ::= `amdg.async_wait` ($asyncToken^)? attr-dict

Similar to ttg.async_wait, but instead of waiting on outstanding ttg.async_commit_groups this op waits on the number of outstanding async instructions/intrinsics, as required for the lowering to LLVM on the AMD backend.
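
Example (an illustrative sketch; num_inst = 0 waits for all outstanding async intrinsics):

%0 = amdg.async_wait %token {num_inst = 0 : i32}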

Traits: MemWaitOpTrait

Interfaces: InferTypeOpInterface

Attributes:

Attribute  MLIR Type            Description
num_inst   ::mlir::IntegerAttr  32-bit signless integer attribute

Operands:

Operand     Description
asyncToken  variadic of async token type

Results:

Result    Description
retToken  async token type

amdg.buffer_atomic_cas (triton::amdgpu::BufferAtomicCASOp)

Atomic CAS op which does compare-exchange to a scalar base pointer and a tensor offset

Syntax:

operation ::= `amdg.buffer_atomic_cas` $sem `,` $scope `,` $cmp `,` $val `,` $ptr `[` $offsets `]`
              (`stride` `=` $stride^)?
              attr-dict `:` type($result)

AMD buffer atomic CAS operation. Buffer atomics are similar to normal atomics, but access global memory via a scalar base pointer and a tensor of offsets instead of a tensor of pointers. Similar to TT_AtomicCASOp: the buffer atomic CAS op loads data at $ptr and atomically stores $val to $ptr if the value at $ptr equals $cmp, with the specified memory semantics and scope. Atomic CAS ops return the pre-op value if it is used; otherwise the value is implicitly dropped. Stride is the distance between the beginnings of contiguous memory chunks; when performing a CAS, it is the address difference between the first elements of each row, in bytes. The compiler tries to determine the stride when converting to buffer ops because it is important for optimizing cache memory accesses.
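
Example (an illustrative sketch assembled from the syntax above; the memory-semantic and scope spellings are assumptions):

%result = amdg.buffer_atomic_cas acq_rel, gpu, %cmp, %val, %ptr[%offsets] stride = %stride : tensor<1024xi32, #blocked>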

Traits: SameLoadStoreOperandsAndResultEncoding

Interfaces: BufferOpInterface

Attributes:

Attribute  MLIR Type                         Description
sem        ::mlir::triton::MemSemanticAttr   allowed 32-bit signless integer cases: 1, 2, 3, 4
scope      ::mlir::triton::MemSyncScopeAttr  allowed 32-bit signless integer cases: 1, 2, 3

Operands:

Operand  Description
ptr      ptr
offsets  tensor of 32-bit signless integer values
cmp      ranked tensor of floating-point or integer or ptr values
val      ranked tensor of floating-point or integer or ptr values
stride   32-bit signless integer

Results:

Result   Description
result   ranked tensor of floating-point or integer or ptr values

amdg.buffer_atomic_rmw (triton::amdgpu::BufferAtomicRMWOp)

Atomic RMW op which reads, modifies, and writes to a scalar base pointer and a tensor offset

Syntax:

operation ::= `amdg.buffer_atomic_rmw` $atomic_rmw_op `,` $sem `,` $scope `,` $value `,` $ptr `[` $offsets `]` (`,` $mask^)?
              (`stride` `=` $stride^)?
              attr-dict `:` type($result)

AMD buffer atomic RMW operation. Buffer atomics are similar to normal atomics, but access global memory via a scalar base pointer and a tensor of offsets instead of a tensor of pointers. As with other buffer ops, the mask is a boolean vector that determines whether a given element should be processed by the atomic RMW op: elements with mask[i] == 0 are dropped (i.e., the atomic is not executed). Similar to TT_AtomicRMWOp: buffer atomic RMW ops load data at $ptr, apply $rmw_op with $val, and store the result to $ptr with the specified memory semantics and scope. Atomic RMW ops return the pre-op value if it is used; otherwise the value is implicitly dropped. Stride is the distance between the beginnings of contiguous memory chunks; when performing an RMW, it is the address difference between the first elements of each row, in bytes. The compiler tries to determine the stride when converting to buffer ops because it is important for optimizing cache memory accesses.
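
Example (an illustrative sketch assembled from the syntax above; the RMW-kind, memory-semantic, and scope spellings are assumptions):

%result = amdg.buffer_atomic_rmw fadd, acq_rel, gpu, %value, %ptr[%offsets], %mask stride = %stride : tensor<1024xf32, #blocked>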

Traits: AttrSizedOperandSegments, SameLoadStoreOperandsAndResultEncoding

Interfaces: BufferOpInterface

Attributes:

Attribute      MLIR Type                         Description
atomic_rmw_op  ::mlir::triton::RMWOpAttr         allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
sem            ::mlir::triton::MemSemanticAttr   allowed 32-bit signless integer cases: 1, 2, 3, 4
scope          ::mlir::triton::MemSyncScopeAttr  allowed 32-bit signless integer cases: 1, 2, 3

Operands:

Operand  Description
ptr      ptr
offsets  tensor of 32-bit signless integer values
value    ranked tensor of floating-point or integer or ptr values
stride   32-bit signless integer
mask     ranked tensor of 1-bit signless integer values

Results:

Result   Description
result   ranked tensor of floating-point or integer or ptr values

amdg.buffer_load (triton::amdgpu::BufferLoadOp)

Load from a scalar base pointer and a tensor offset

Syntax:

operation ::= `amdg.buffer_load` $ptr `[` $offsets `]` (`,` $mask^)? (`,` $other^)?
              oilist(`cacheModifier` `=` $cache)
              (`stride` `=` $stride^)?
              attr-dict `:` type($result)

AMD buffer load operation. A buffer load is similar to a normal load, but it accesses global memory via a scalar base pointer and a tensor of offsets instead of a tensor of pointers. The other fields are similar to a normal load: the mask is a boolean vector that determines whether a given element should be read from memory, and other is the element to return on lane i when mask[i] == 0. Stride is the distance between the beginnings of contiguous memory chunks; when loading a block, it is the address difference between the first elements of each row, in bytes. The compiler tries to determine the stride when converting to buffer ops because it is important for optimizing cache memory accesses.
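
Example (an illustrative sketch assembled from the syntax above; names and types are hypothetical):

%result = amdg.buffer_load %ptr[%offsets], %mask, %other stride = %stride : tensor<1024xf32, #blocked>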

Traits: AttrSizedOperandSegments, SameLoadStoreOperandsAndResultEncoding

Interfaces: BufferOpInterface

Attributes:

Attribute  MLIR Type                          Description
cache      ::mlir::triton::CacheModifierAttr  allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6, 7

Operands:

Operand  Description
ptr      ptr
offsets  tensor of 32-bit signless integer values
stride   32-bit signless integer
mask     ranked tensor of 1-bit signless integer values
other    ranked tensor of floating-point or integer or ptr values

Results:

Result   Description
result   ranked tensor of floating-point or integer or ptr values

amdg.buffer_load_to_local (triton::amdgpu::BufferLoadToLocalOp)

Load from a scalar base pointer and a tensor offset to shared memory

Syntax:

operation ::= `amdg.buffer_load_to_local` $ptr `[` $offsets `]` (`mask` `=` $mask^)? (`other` `=` $other^)? (`stride` `=` $stride^)?
              oilist(`cacheModifier` `=` $cache) `into` $dest
              attr-dict `:` type($ptr) `[` type($offsets) `]` type($other) `->` type($dest)

AMD buffer load operation. Similar to the amdg.buffer_load op, but writes directly to shared memory instead of into registers. Contiguity is the maximum number of elements that can be loaded in a single vector with the given layout and mask. This allows buffer_load_to_local to be used even when the alignment cannot be proven from the IR.
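
Example (an illustrative sketch assembled from the syntax above; names and types are hypothetical):

%token = amdg.buffer_load_to_local %ptr[%offsets] mask = %mask other = %other into %dest
    : !tt.ptr<f16> [tensor<64x64xi32, #blocked>] tensor<64x64xf16, #blocked> -> !ttg.memdesc<64x64xf16, #shared, #smem, mutable>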

Traits: AttrSizedOperandSegments

Interfaces: BufferOpInterface, InferTypeOpInterface

Attributes:

Attribute   MLIR Type                          Description
cache       ::mlir::triton::CacheModifierAttr  allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6, 7
contiguity  ::mlir::IntegerAttr                32-bit signless integer attribute

Operands:

Operand  Description
dest     memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system
ptr      ptr
offsets  tensor of 32-bit signless integer values
mask     ranked tensor of 1-bit signless integer values
other    ranked tensor of floating-point or integer or ptr values
stride   32-bit signless integer

Results:

Result  Description
token   async token type

amdg.buffer_store (triton::amdgpu::BufferStoreOp)

Store into a scalar base pointer and a tensor offset

Syntax:

operation ::= `amdg.buffer_store` $value `,` $ptr `[` $offsets `]` (`,` $mask^)?
              oilist(`cacheModifier` `=` $cache)
              (`stride` `=` $stride^)?
              attr-dict `:` type($value)

AMD buffer store operation. A buffer store is similar to a normal store, but it accesses global memory via a scalar base pointer and a tensor of offsets instead of a tensor of pointers. The other fields are similar to a normal store: the mask is a boolean vector that determines whether a given element should be written to memory, and value is the tensor of elements to write on lane i when mask[i] == 1. Stride is the distance between the beginnings of contiguous memory chunks; when storing a block, it is the address difference between the first elements of each row, in bytes. The compiler tries to determine the stride when converting to buffer ops because it is important for optimizing cache memory accesses.
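
Example (an illustrative sketch assembled from the syntax above; names and types are hypothetical):

amdg.buffer_store %value, %ptr[%offsets], %mask stride = %stride : tensor<1024xf32, #blocked>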

Traits: AttrSizedOperandSegments, SameLoadStoreOperandsEncoding

Interfaces: BufferOpInterface

Attributes:

Attribute  MLIR Type                          Description
cache      ::mlir::triton::CacheModifierAttr  allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6, 7

Operands:

Operand  Description
value    ranked tensor of floating-point or integer or ptr values
ptr      ptr
offsets  tensor of 32-bit signless integer values
stride   32-bit signless integer
mask     ranked tensor of 1-bit signless integer values

amdg.concat (triton::amdgpu::ConcatOp)

Concat operation

Syntax:

operation ::= `amdg.concat` $sources attr-dict `:` type($sources) `->` type($result)

The “concat” operation combines a list of source n-dimensional tensors into a single larger destination tensor.

All source tensors must have the same shape, element type, and encoding. The concatenation dimension is inferred from the source and destination shapes provided by the user. For example, two tensors of shape 64x128 can produce a destination shape of 128x128, indicating concatenation along dimension 0; or 64x256, indicating concatenation along dimension 1.

Generally, source tensors passed as op arguments can be arranged into the resulting shape in multiple ways. For example, given four tensors of shape 64x64:

concat s0<64x64>, s1<64x64>, s2<64x64>, s3<64x64> -> <128x128>

They can be laid out in different configurations within the result tensor:

  1) s0 s1     2) s0 s2
     s2 s3        s1 s3

From a logical tensor perspective, the source tensors are treated as elements of a tensor of tensors. In other words, the 1-D array of input tensors is conceptually reshaped into an n-D grid. The semantics of this op assume a row-major order (or its n-D generalization), meaning the fastest-varying dimension is filled first, and the slowest-varying dimension is filled last. In the example above, this corresponds to layout 1).

The source and destination tensors must have identical linear layouts at the CTA tile level. That is, all base vectors for input dimensions must match, except for the register input dimension. The register basis must align on the subset that defines the logical tensor shape of a single CTA tile.

This ensures that the concatenation is a no-op, meaning no data rearrangement among threads is required to assemble the destination tensor with the given shape and layout. However, the order of CTA tiles within the layout does not need to match between source and destination layouts. It is the responsibility of the op’s lowering logic to handle this correctly.

This op is designed to work on logical tensors directly, avoiding the need for complex layout reinterpretation or reshaping. For example, the tt.join operation only supports concatenation along the innermost dimension, and requires that the resulting innermost dimension provide 2 elements per thread, distributed across registers. In contrast, this concat op imposes no constraints on the concatenation dimension or the size of dimensions.

  • sources: a list of the input tensors.

Example 1:

#blocked = #ttg.blocked<{sizePerThread = [1, 8],
    threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
%0 = amdg.concat %arg0, %arg1 : tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>
  -> tensor<64x64xf32, #blocked>

Example 2:

#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
#dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}>
%0 = amdg.concat %arg0, %arg1, %arg2, %arg3 : tensor<128x128xf16, #src_layout>, tensor<128x128xf16, #src_layout>, tensor<128x128xf16, #src_layout>,
                                                tensor<128x128xf16, #src_layout> -> tensor<256x256xf16, #dst_layout>

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand  Description
sources  variadic of ranked tensor of floating-point or integer or ptr values

Results:

Result   Description
result   ranked tensor of any type values

amdg.cond_barrier (triton::amdgpu::CondBarrierOp)

Conditionally set barriers to synchronize partial threads in a block

Syntax:

operation ::= `amdg.cond_barrier` $pred attr-dict

cond_barrier executes a barrier instruction only when the given argument is true. This provides a way to synchronize a subset of the threads in a block, deliberately diverging the execution sequences. However, the user must guarantee that all threads converge at the end, e.g., by calling cond_barrier(true) with the remaining threads. Conceptually, this is similar to having an execution barrier inside an if statement. This op allows us to avoid blocking the whole block when suitable, to help scheduling. NB: this doesn’t set any memory fence.
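
Example (an illustrative sketch of the diverge/reconverge pattern described above; the predicate names are hypothetical):

// Stall only the second half of the warps here ...
amdg.cond_barrier %isSecondHalf
// ... then have the remaining (first-half) warps arrive at a matching
// barrier later, so that all threads converge again.
amdg.cond_barrier %isFirstHalf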

Operands:

Operand  Description
pred     1-bit signless integer

amdg.extract_slice (triton::amdgpu::ExtractSliceOp)

Extract slice operation

Syntax:

operation ::= `amdg.extract_slice` $source $static_offsets attr-dict `:` type($source) `to` type($result)

The “extract_slice” operation enables extracting a slice of a tensor in registers.

The “extract_slice” operation supports the following arguments:

  • source: the base tensor on which to create a view tensor

  • offsets: offsets into the base tensor at which to create the view

In distributed layouts, tensors are divided into CTA tiles. A CTA tile represents the smallest contiguous portion of a tensor that is distributed across all threads and warps within a workgroup. The ExtractSlice operation extracts a portion of the tensor that is a multiple of CTA tiles.

The source and destination must have matching linear layouts at the CTA tile level. This ensures that the extract_slice is a no-op, meaning no data rearrangement between threads is required to extract the destination tensor with the given shape and layout.

(Diagram: a source tensor partitioned into CTA tiles, each tile distributed across warps W0–W3, with the extracted slice covering a whole-tile subset of the source tensor.)

This op is designed to work on logical tensors directly, avoiding the need for complex layout reinterpretation or reshaping. For example, the tt.split operation only supports splitting along the innermost dimension, and requires that the resulting innermost dimension provide 2 elements per thread, distributed across registers. In contrast, extract_slice op imposes no constraints on the extraction dimension or the size of dimensions.

Example 1:

#blocked = #ttg.blocked<{sizePerThread = [1, 8],
    threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [0, 1]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8],
    threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}>
%1 = ttg.convert_layout %0 : tensor<128x128xf16, #blocked>
    -> tensor<128x128xf16, #blocked1>
// create a slice of base tensor %1 with static offsets
%2 = amdg.extract_slice %1 [0, 0] :
  tensor<128x128xf16, #blocked1> to tensor<128x32xf16, #blocked1>

Example 1 shows how the “extract_slice” operation may be used. In this example, a new slice of size 128x32 is created. “extract_slice” works on tensors where the desired slice has the same layout on a CTA tile as the source tensor. “%0” cannot be sliced directly, as the resulting slice would not satisfy this condition; it therefore needs to be converted to a layout suitable for slicing. The “#blocked1” layout is appropriate for this, as it keeps sizePerThread the same and thus preserves coalescing properties. In order to utilize all threads in a warp, “threadsPerWarp” is set to [16, 4] in the new layout. Carrying out this layout conversion before using “extract_slice” ensures that slicing still uses all threads efficiently. The size of the slice is determined by the result type.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute       MLIR Type                  Description
static_offsets  ::mlir::DenseI64ArrayAttr  i64 dense array attribute

Operands:

Operand  Description
source   ranked tensor of any type values

Results:

Result   Description
result   ranked tensor of any type values

amdg.in_thread_transpose (triton::amdgpu::InThreadTransposeOp)

Perform transpose of register values belonging to each threads

Syntax:

operation ::= `amdg.in_thread_transpose` $src attr-dict `:` type($src) `->` type($result)

This operation performs a layout transpose over the values held in registers by each thread. Specifically, given the input’s blocked layout, it transposes the last two dimensions (rank-1 and rank-2) along the register dimension of the underlying linear layout.

Conversion example:

  • input layout: blocked layout with sizePerThread=[2, 2], order=[0, 1]. Its linear layout register bases = [[1, 0], [2, 0], [0, 1], [0, 2]]

  • output layout: same thread and warp bases as in input, register bases = [[0, 1], [0, 2], [1, 0], [2, 0]]

This operation enables efficient coalesced loading from HBM followed by vectorized writing to shared memory in cases where the HBM and shared memory order differ and the target AMD hardware does not natively support this transposition. It is a specific variant of ttg.convert_layout and will be converted to ttg.convert_layout when lowering to LLVM. We do not want this conversion to be optimized out, because we need to explicitly materialize instructions that transpose within each thread after loading from HBM and before writing to shared memory.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand  Description
src      ranked tensor of floating-point or integer or ptr values

Results:

Result   Description
result   ranked tensor of floating-point or integer or ptr values

amdg.init_barrier (triton::amdgpu::InitBarrierOp)

Initialize a barrier in the given shared memory allocation.

Syntax:

operation ::= `amdg.init_barrier` $alloc `,` $count attr-dict `:` qualified(type($alloc))

Initializes a shared memory allocation with mbarrier information. alloc is a descriptor to the shared memory allocation. count is the number of arrives expected by the barrier.
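
Example (an illustrative sketch, mirroring the allocation type used by amdg.arrive_barrier above; the barrier expects one arrival):

amdg.init_barrier %alloc, 1 : !ttg.memdesc<1xi64, #shared, #smem, mutable>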

Attributes:

Attribute  MLIR Type            Description
count      ::mlir::IntegerAttr  32-bit signless integer attribute

Operands:

Operand  Description
alloc    memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system

amdg.instruction_sched_hint (triton::amdgpu::InstructionSchedHint)

A placeholder op for instruction scheduling hints within a basic block

Syntax:

operation ::= `amdg.instruction_sched_hint` attr-dict

A placeholder op for instruction scheduling hints applied to instructions within a basic block where the placeholder op is located. This op is primarily intended to be used to adjust instruction scheduling inside the resulting main loop of a tt.dot operation. It’s easier to identify dot ops at a high level and, thus, to mark intended scheduling regions. The hint ops are eventually lowered into LLVM AMDGPU instruction scheduling primitives, which are meant to control how different kinds of instructions (valu/mfma, global/shared memory, etc.) should interleave for better instruction level parallelism.

Attributes:

Attribute  MLIR Type                              Description
variant    ::mlir::triton::amdgpu::SchedHintAttr  Instruction Scheduling Hints for AMD GPUs

amdg.local_load_packed_tranposed (triton::amdgpu::LocalLoadPackedTransposedOp)

Load a transposed packed tensor from shared memory into a distributed tensor

Syntax:

operation ::= `amdg.local_load_packed_tranposed` $src (`token` $token^)? attr-dict `:` qualified(type($src)) `->` type($result)

Requires an M/N-packed and M/N-contiguous tensor in shared memory and yields a K-packed, K-contiguous tensor in registers. The packing change alters the shape of the tensor by doubling the M/N dimension and halving the K dimension. For example, if A is 16x64 in shared memory, the result of this operation will be 32x32.
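
Example (an illustrative sketch using the 16x64 -> 32x32 shapes from the description; the layout names are hypothetical):

%result = amdg.local_load_packed_tranposed %src : !ttg.memdesc<16x64xf16, #shared, #smem, mutable> -> tensor<32x32xf16, #dot_operand>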

Traits: LocalLoadTrait

Operands:

Operand  Description
src      memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system
token    async token type

Results:

Result   Description
result   ranked tensor of floating-point or integer or ptr values

amdg.masked_load (triton::amdgpu::MaskedLoadOp)

Masked load operation

Syntax:

operation ::= `amdg.masked_load` $ptr `,` $mask `,` $falseVal (`,` $multicastMask^)?
              oilist(`cacheModifier` `=` $cache)
              (`forceNoAlias` $forceNoAlias^)?
              attr-dict `:` functional-type(operands, results)

Load operation with masking and multicast support. If the mask is true, loads from the given pointer. Works with LLVM types as a utility op for making LLVM conversion easier. On architectures supporting multicast, the multicastMask specifies which CTAs in the cluster request the same data. This allows the hardware to efficiently broadcast the data to multiple CTAs in the cluster.
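
Example (an illustrative sketch assembled from the syntax above; the LLVM-level types are assumptions):

%result = amdg.masked_load %ptr, %mask, %falseVal : (!llvm.ptr<1>, i1, f32) -> f32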

Attributes:

Attribute     MLIR Type                          Description
cache         ::mlir::triton::CacheModifierAttr  allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6, 7
forceNoAlias  ::mlir::BoolAttr                   bool attribute

Operands:

Operand        Description
ptr            LLVM pointer type
mask           1-bit signless integer
falseVal       LLVM dialect-compatible type
multicastMask  16-bit signless integer

Results:

Result   Description
result   LLVM dialect-compatible type

amdg.masked_store (triton::amdgpu::MaskedStoreOp)

Masked Store operation

Syntax:

operation ::= `amdg.masked_store` $ptr `,` $value `,` $mask
              oilist(`cacheModifier` `=` $cache)
              (`forceNoAlias` $forceNoAlias^)?
              attr-dict `:` type(operands)

Store operation with masking support. If the mask is true, stores to the given pointer. Works with LLVM types as a utility op for making LLVM conversion easier.
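
Example (an illustrative sketch assembled from the syntax above; the LLVM-level types are assumptions):

amdg.masked_store %ptr, %value, %mask : !llvm.ptr<1>, f32, i1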

Attributes:

Attribute     MLIR Type                          Description
cache         ::mlir::triton::CacheModifierAttr  allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6, 7
forceNoAlias  ::mlir::BoolAttr                   bool attribute

Operands:

Operand  Description
ptr      LLVM pointer type
value    LLVM dialect-compatible type
mask     1-bit signless integer

amdg.memory_counter_wait (triton::amdgpu::MemoryCounterWaitOp)

Wait for specified hardware counters

Syntax:

operation ::= `amdg.memory_counter_wait` oilist( `load` `(` $load `)` | `store` `(` $store `)` | `ds` `(` $ds `)` ) attr-dict

Wait for the specified counters to be less-than or equal-to the provided values before continuing.

Counters can lower to different instructions on different architectures, including clamping to some hardware-supported maximum value or combining multiple counters into one.
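
Example (an illustrative sketch assembled from the syntax above; waits until at most two loads and no DS operations are outstanding):

amdg.memory_counter_wait load(2) ds(0)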

Attributes:

Attribute  MLIR Type            Description
load       ::mlir::IntegerAttr  32-bit signless integer attribute
store      ::mlir::IntegerAttr  32-bit signless integer attribute
ds         ::mlir::IntegerAttr  32-bit signless integer attribute

amdg.scaled_upcast_fp4 (triton::amdgpu::ScaledUpcastFp4Op)

Upcast fp4 and then multiply scale

Syntax:

operation ::= `amdg.scaled_upcast_fp4` $input `scale` $scale attr-dict
              `:` type($input) `,` type($scale) `->` type($output)

Upcast fp4 (e2m1) values packed as i8 values and multiply with the given E8M0 scale encoded as BF16. This maps to v_cvt_scalef32_* intrinsics on the AMD CDNA4 architecture.

The lower 4 bits of the i8s represent the first fp4 element, and the upper 4 bits the second fp4 element.

The axis attribute specifies the axis along which the fp4 elements are packed.

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), UpcastFpOpInterface

Effects: MemoryEffects::Effect{}

Attributes:

Attribute  MLIR Type            Description
axis       ::mlir::IntegerAttr  32-bit signless integer attribute

Operands:

Operand  Description
input    ranked tensor of 8-bit signless integer values
scale    ranked tensor of bfloat16 type values

Results:

Result   Description
output   ranked tensor of 16-bit float or bfloat16 type or 32-bit float values

amdg.scaled_upcast_fp8 (triton::amdgpu::ScaledUpcastFp8Op)

Upcast Fp8 and then multiply scale

Syntax:

operation ::= `amdg.scaled_upcast_fp8` $input `scale` $scale attr-dict
              `:` type($input) `,` type($scale) `->` type($output)

Upcast fp8 (e4m3/e5m2) values and multiply with the given E8M0 scale encoded as BF16. This maps to v_cvt_scalef32_* intrinsics on the AMD CDNA4 architecture.
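
Example (an illustrative sketch assembled from the syntax above; the op is elementwise, so all shapes and encodings match):

%out = amdg.scaled_upcast_fp8 %input scale %scale : tensor<32x32xf8E4M3FN, #blocked>, tensor<32x32xbf16, #blocked> -> tensor<32x32xbf16, #blocked>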

Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultEncoding, SameOperandsAndResultShape

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), UpcastFpOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand  Description
input    ranked tensor of f8E4M3FN type or f8E5M2 type values
scale    ranked tensor of bfloat16 type values

Results:

Result   Description
output   ranked tensor of 16-bit float or bfloat16 type or 32-bit float values

amdg.upcast_mxfp (triton::amdgpu::UpcastMXFPOp)

Convert an mxfp tensor to bf16/fp16

Syntax:

operation ::= `amdg.upcast_mxfp` $src `,` $scale  `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)

Compute the bf16 values encoded in the given mxfp tensor, as per the OCP Microscaling Formats (MX) v1.0 specification: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

Traits: AlwaysSpeculatableImplTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute  MLIR Type                             Description
fp_type    ::mlir::triton::ScaleDotElemTypeAttr  allowed 32-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6
fastMath   ::mlir::BoolAttr                      bool attribute

Operands:

Operand  Description
src      ranked tensor of floating-point or integer or ptr values
scale    ranked tensor of floating-point or integer or ptr values

Results:

Result   Description
result   ranked tensor of floating-point or integer or ptr values

amdg.wait_barrier (triton::amdgpu::WaitBarrierOp)

Wait until the mbarrier phase completes.

Syntax:

operation ::= `amdg.wait_barrier` $alloc `,` $phase attr-dict `:` qualified(type($alloc))

Blocks the program progress until the mbarrier object in alloc completes its current phase.
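
Example (an illustrative sketch, mirroring the allocation type used by amdg.arrive_barrier above; %phase is the phase to wait on):

amdg.wait_barrier %alloc, %phase : !ttg.memdesc<1xi64, #shared, #smem, mutable>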

Operands:

Operand  Description
alloc    memory descriptor type (::mlir::triton::gpu::MemDescType) in Triton IR type system
phase    32-bit signless integer