TritonOps

tt.call (triton::CallOp)

Call operation

Syntax:

operation ::= `tt.call` $callee `(` $operands `)` attr-dict `:` functional-type($operands, results)

The tt.call operation represents a direct call to a function that is within the same symbol scope as the call. The operands and result types of the call must match the specified function type. The callee is encoded as a symbol reference attribute named “callee”.

Example:

%2 = tt.call @my_add(%0, %1) : (f32, f32) -> f32

Traits: TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: CallOpInterface, SymbolUserOpInterface

Attributes:

Attribute | MLIR Type | Description
callee | ::mlir::FlatSymbolRefAttr | flat symbol reference attribute

Operands:

Operand

Description

operands

variadic of any type

Results:

Result

Description

«unnamed»

variadic of any type

tt.func (triton::FuncOp)

An operation with a name containing a single SSACFG region

Operations within the function cannot implicitly capture values defined outside of the function, i.e., functions are IsolatedFromAbove. All external references must use function arguments or attributes that establish a symbolic connection (e.g., symbols referenced by name via a string attribute like SymbolRefAttr). An external function declaration (used when referring to a function declared in some other module) has no body. While the MLIR textual form provides a nice inline syntax for function arguments, they are internally represented as “block arguments” to the first block in the region.

Only dialect attribute names may be specified in the attribute dictionaries for function arguments, results, or the function itself.

Example:

// External function declarations.
tt.func private @abort()
tt.func private @scribble(i32, i64, memref<? x 128 x f32, #layout_map0>) -> f64

// A function that returns its argument twice:
tt.func @count(%x: i64) -> (i64, i64)
  attributes {fruit = "banana"} {
  tt.return %x, %x : i64, i64
}

// A function with an argument attribute
tt.func private @example_fn_arg(%x: i32 {swift.self = unit})

// A function with a result attribute
tt.func private @example_fn_result() -> (f64 {dialectName.attrName = 0 : i64})

// A function with an attribute
tt.func private @example_fn_attr() attributes {dialectName.attrName = false}

Traits: AffineScope, AutomaticAllocationScope, IsolatedFromAbove, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: CallableOpInterface, FunctionOpInterface, OpAsmOpInterface, Symbol

Attributes:

Attribute | MLIR Type | Description
sym_name | ::mlir::StringAttr | string attribute
function_type | ::mlir::TypeAttr | type attribute of function type
sym_visibility | ::mlir::StringAttr | string attribute
arg_attrs | ::mlir::ArrayAttr | Array of dictionary attributes
res_attrs | ::mlir::ArrayAttr | Array of dictionary attributes

tt.return (triton::ReturnOp)

Function return operation

Syntax:

operation ::= `tt.return` attr-dict ($srcs^ `:` type($srcs))?

The tt.return operation represents a return operation within a function. The operation takes a variable number of operands and produces no results. The number and types of the operands must match the signature of the function that contains the operation.

Example:

tt.func @foo() -> (i32, f8E5M2) {
  ...
  tt.return %0, %1 : i32, f8E5M2
}

Traits: AlwaysSpeculatableImplTrait, HasParent<FuncOp>, ReturnLike, TensorSizeTrait, Terminator, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

srcs

variadic of any type

tt.addptr (triton::AddPtrOp)

Syntax:

operation ::= `tt.addptr` $ptr `,` $offset attr-dict `:` type($result) `,` type($offset)
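tt.addptr has no prose description here. As we understand its lowering (to an LLVM getelementptr on the pointee type), the offset is counted in elements of the pointee type, not in bytes. The sketch below models that reading in Python; pointers are represented as integer byte addresses and `addptr` and `elem_size` are hypothetical names, not Triton API.

```python
from typing import List


def addptr(ptrs: List[int], offsets: List[int], elem_size: int) -> List[int]:
    """Elementwise pointer + offset, with offsets counted in elements.

    Pointers are modelled as integer byte addresses (an assumption of this
    sketch); elem_size is the pointee size in bytes, e.g. 4 for f32.
    """
    if len(ptrs) != len(offsets):
        raise ValueError("ptr and offset must have the same shape")
    return [p + o * elem_size for p, o in zip(ptrs, offsets)]


# A splatted base pointer plus offsets 0..3, as in a typical f32 block access.
base = 0x1000
print([hex(p) for p in addptr([base] * 4, [0, 1, 2, 3], elem_size=4)])
```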

Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultEncoding, SameOperandsAndResultShape, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

ptr

ptr or ranked tensor of ptr values

offset

integer or ranked tensor of integer values

Results:

Result

Description

result

ptr or ranked tensor of ptr values

tt.advance (triton::AdvanceOp)

Advance a tensor pointer by offsets

Syntax:

operation ::= `tt.advance` $ptr `,` `[` $offsets `]` attr-dict `:` type($result)

Traits: AlwaysSpeculatableImplTrait, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

ptr

ptr

offsets

variadic of 32-bit signless integer

Results:

Result

Description

result

ptr

tt.assert (triton::AssertOp)

Device-side assert, as in CUDA for correctness checking

Syntax:

operation ::= `tt.assert` $condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)

tt.assert takes a condition tensor, a message string, a file string, a function string, and a line number. If the condition is false, the message is printed, and the program is aborted.

Traits: TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::triton::GlobalMemory}

Attributes:

Attribute | MLIR Type | Description
message | ::mlir::StringAttr | string attribute
file | ::mlir::StringAttr | string attribute
func | ::mlir::StringAttr | string attribute
line | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand

Description

condition

ranked tensor of floating-point or integer or ptr values

tt.atomic_cas (triton::AtomicCASOp)

Atomic compare-and-swap

Syntax:

operation ::= `tt.atomic_cas` $sem `,` $scope `,` $ptr `,` $cmp `,` $val attr-dict `:`
              functional-type(operands, $result)

Compare $cmp with the data $old at location $ptr.

If $old == $cmp, store $val to $ptr;

otherwise, store $old back to $ptr, leaving the memory unchanged.

Return $old.
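A sequential Python model of a single compare-and-swap lane makes the semantics concrete. It is a sketch only: memory is a dictionary keyed by address, atomicity, $sem and $scope are not modelled, and `atomic_cas` is a hypothetical name, not Triton API.

```python
from typing import Dict


def atomic_cas(memory: Dict[int, int], ptr: int, cmp: int, val: int) -> int:
    """Sequential model of one compare-and-swap lane.

    Atomicity, memory ordering ($sem) and scope ($scope) are not modelled.
    """
    old = memory[ptr]
    if old == cmp:
        memory[ptr] = val  # swap succeeded
    # else: memory[ptr] keeps old, which is what "store $old" amounts to
    return old


# Spin-lock style usage: try to take a lock stored at address 0.
mem = {0: 0}
acquired = atomic_cas(mem, 0, cmp=0, val=1) == 0  # lock was free, now held
retry = atomic_cas(mem, 0, cmp=0, val=1) == 0     # lock already held, fails
```

The returned old value is what lets a caller tell whether the swap happened, which is why lock-acquire loops compare it against $cmp.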

Traits: SameOperandsAndResultEncoding, SameOperandsAndResultShape, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::triton::GlobalMemory}, MemoryEffects::Effect{MemoryEffects::Write on ::mlir::triton::GlobalMemory}

Attributes:

Attribute | MLIR Type | Description
sem | ::mlir::triton::MemSemanticAttr | allowed 32-bit signless integer cases: 1, 2, 3, 4. Enum cases: relaxed (`RELAXED`), acquire (`ACQUIRE`), release (`RELEASE`), acq_rel (`ACQUIRE_RELEASE`)
scope | ::mlir::triton::MemSyncScopeAttr | allowed 32-bit signless integer cases: 1, 2, 3. Enum cases: gpu (`GPU`), cta (`CTA`), sys (`SYSTEM`)

Operands:

Operand

Description

ptr

ptr or ranked tensor of ptr values

cmp

floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

val

floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

Results:

Result

Description

result

floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

tt.atomic_rmw (triton::AtomicRMWOp)

Atomic read-modify-write

Syntax:

operation ::= `tt.atomic_rmw` $atomic_rmw_op `,` $sem `,` $scope `,` $ptr `,` $val (`,` $mask^)?  attr-dict `:`
              functional-type(operands, $result)

Load the data at $ptr, apply $atomic_rmw_op to it and $val, and store the result to $ptr.

Return the old value at $ptr.
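The read-modify-write semantics can be sketched in Python for a subset of the `RMWOp` cases. Assumptions of this sketch: lanes are processed one after another in index order (the real order among lanes hitting the same address is not specified here), masked-off lanes leave memory untouched and their result is shown as `None`, and umax/umin are omitted because they depend on the operand bit width. `RMW_OPS` and `atomic_rmw` are hypothetical names.

```python
import operator
from typing import Callable, Dict, List, Optional, Sequence

# Combine functions for a subset of the RMWOp enum cases.
# umax/umin are left out because they need the operand bit width.
RMW_OPS: Dict[str, Callable] = {
    "and": operator.and_,
    "or": operator.or_,
    "xor": operator.xor,
    "add": operator.add,
    "fadd": operator.add,
    "max": max,
    "min": min,
    "exch": lambda old, val: val,
}


def atomic_rmw(op: str, memory: Dict[int, float], ptrs: Sequence[int],
               vals: Sequence[float],
               mask: Optional[Sequence[bool]] = None) -> List[Optional[float]]:
    """Apply `op` lane by lane, in index order, returning the old values."""
    results: List[Optional[float]] = []
    for i, (p, v) in enumerate(zip(ptrs, vals)):
        if mask is not None and not mask[i]:
            results.append(None)  # masked-off lane: memory untouched, result not modelled
            continue
        old = memory[p]
        memory[p] = RMW_OPS[op](old, v)
        results.append(old)
    return results


mem = {0: 1, 1: 10}
print(atomic_rmw("add", mem, [0, 0, 1], [2, 3, 5]), mem)  # [1, 3, 10] {0: 6, 1: 15}
```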

Traits: SameOperandsAndResultEncoding, SameOperandsAndResultShape, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::triton::GlobalMemory}, MemoryEffects::Effect{MemoryEffects::Write on ::mlir::triton::GlobalMemory}

Attributes:

Attribute | MLIR Type | Description
atomic_rmw_op | ::mlir::triton::RMWOpAttr | allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10. Enum cases: and (`AND`), or (`OR`), xor (`XOR`), add (`ADD`), fadd (`FADD`), max (`MAX`), min (`MIN`), umax (`UMAX`), umin (`UMIN`), exch (`XCHG`)
sem | ::mlir::triton::MemSemanticAttr | allowed 32-bit signless integer cases: 1, 2, 3, 4. Enum cases: relaxed (`RELAXED`), acquire (`ACQUIRE`), release (`RELEASE`), acq_rel (`ACQUIRE_RELEASE`)
scope | ::mlir::triton::MemSyncScopeAttr | allowed 32-bit signless integer cases: 1, 2, 3. Enum cases: gpu (`GPU`), cta (`CTA`), sys (`SYSTEM`)

Operands:

Operand

Description

ptr

ptr or ranked tensor of ptr values

val

floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

mask

1-bit signless integer or ranked tensor of 1-bit signless integer values

Results:

Result

Description

result

floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

tt.bitcast (triton::BitcastOp)

Cast between types of the same bitwidth

Syntax:

operation ::= `tt.bitcast` $src attr-dict `:` type($src) `->` type($result)
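A bitcast reinterprets the same bits under a different type of equal width; no numeric conversion happens. The Python sketch below shows this for f32 and i32 using the standard `struct` module; the function names are illustrative only.

```python
import struct


def bitcast_f32_to_i32(x: float) -> int:
    """Reinterpret the 32 bits of an f32 as a signed i32 (no value conversion)."""
    return struct.unpack("<i", struct.pack("<f", x))[0]


def bitcast_i32_to_f32(i: int) -> float:
    """Reinterpret the 32 bits of an i32 as an f32."""
    return struct.unpack("<f", struct.pack("<i", i))[0]


print(hex(bitcast_f32_to_i32(1.0)))  # 0x3f800000
```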

Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultEncoding, SameOperandsAndResultShape, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

src

floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

Results:

Result

Description

result

floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

tt.broadcast (triton::BroadcastOp)

Broadcast a tensor

Syntax:

operation ::= `tt.broadcast` $src attr-dict `:` type($src) `->` type($result)

For a given tensor, broadcast changes one or more dimensions with size 1 to a new size, e.g. tensor<1x32x1xf32> -> tensor<2x32x4xf32>. You cannot change the size of a non-1 dimension.
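The shape rule above can be written as a small Python check. It is a sketch that assumes source and result have the same rank (the description only talks about resizing existing size-1 dimensions); `can_broadcast` is a hypothetical helper, not part of Triton.

```python
from typing import Sequence


def can_broadcast(src: Sequence[int], dst: Sequence[int]) -> bool:
    """True if every dimension of src either equals dst's or is 1 (same rank)."""
    if len(src) != len(dst):
        return False  # assumption: tt.broadcast keeps the rank
    return all(s == d or s == 1 for s, d in zip(src, dst))


print(can_broadcast([1, 32, 1], [2, 32, 4]))  # True
print(can_broadcast([2, 32], [4, 32]))        # False: a non-1 dimension changes
```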

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultEncoding, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

src

ranked tensor of floating-point or integer or ptr values

Results:

Result

Description

result

ranked tensor of floating-point or integer or ptr values

tt.cat (triton::CatOp)

Concatenate 2 tensors

Syntax:

operation ::= `tt.cat` $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)

Traits: SameOperandsAndResultElementType, SameTypeOperands, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

lhs

ranked tensor of floating-point or integer or ptr values

rhs

ranked tensor of floating-point or integer or ptr values

Results:

Result

Description

result

ranked tensor of floating-point or integer or ptr values

tt.clampf (triton::ClampFOp)

Clamp operation for floating point types

Syntax:

operation ::= `tt.clampf` $x `,` $min `,` $max `,` `propagateNan` `=` $propagateNan attr-dict `:` type($result)

Clamp operation for floating point types.

The operation takes three arguments: x, min, and max. It returns a tensor of the same shape as x with its values clamped to the range [min, max].

Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultType, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
propagateNan | ::mlir::triton::PropagateNanAttr | allowed 32-bit signless integer cases: 0, 65535. Enum cases: none (`NONE`), all (`ALL`)

Operands:

Operand

Description

x

floating-point or ranked tensor of floating-point values

min

floating-point or ranked tensor of floating-point values

max

floating-point or ranked tensor of floating-point values

Results:

Result

Description

result

floating-point or ranked tensor of floating-point values

tt.dot (triton::DotOp)

Dot

Syntax:

operation ::= `tt.dot` $a`,` $b`,` $c (`,` `inputPrecision` `=` $inputPrecision^)? attr-dict `:`
              type($a) `*` type($b) `->` type($d)

$d = matrix_multiply($a, $b) + $c. $inputPrecision describes how to use the tensor cores (TC) when the inputs are f32. It can be one of: tf32, tf32x3, ieee.

tf32: use TC with tf32 ops.

tf32x3: implement the 3xTF32 trick. For more info, see the pass in F32DotTC.cpp.

ieee: don't use TC; implement dot in software.

If the GPU does not have tensor cores or the inputs are not f32, this flag is ignored.
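The tf32x3 option refers to the "3xTF32" technique: each fp32 input is split into a TF32 "big" part plus a TF32 "small" remainder, and each product is approximated by three TF32 products (big*big + big*small + small*big), dropping small*small. The Python sketch below models only the numerics on a short dot product. Two simplifying assumptions: `to_tf32` truncates the mantissa instead of rounding, and accumulation happens in Python doubles rather than fp32. It says nothing about how F32DotTC.cpp actually rewrites the IR; all names are hypothetical.

```python
import struct
from typing import Sequence


def f32(x: float) -> float:
    """Round a Python float (double) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32(x: float) -> float:
    """Keep 10 of the 23 fp32 mantissa bits (TF32 precision).

    Assumption: plain truncation; real hardware conversions may round instead.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # clear the low 13 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot_tf32(a: Sequence[float], b: Sequence[float]) -> float:
    """One TF32 product per term (the tf32 mode)."""
    return sum(to_tf32(x) * to_tf32(y) for x, y in zip(a, b))


def dot_tf32x3(a: Sequence[float], b: Sequence[float]) -> float:
    """3xTF32: split each value into big + small TF32 parts, drop small*small."""
    total = 0.0
    for x, y in zip(a, b):
        x_big = to_tf32(x)
        x_small = to_tf32(x - x_big)
        y_big = to_tf32(y)
        y_small = to_tf32(y - y_big)
        total += x_big * y_big + x_big * y_small + x_small * y_big
    return total


a = [f32(v) for v in (1 / 3, 2 / 7, 5 / 11)]
b = [f32(v) for v in (3 / 7, 1 / 9, 7 / 13)]
exact = sum(x * y for x, y in zip(a, b))  # fp32 inputs, products exact in double
err_tf32 = abs(dot_tf32(a, b) - exact)
err_tf32x3 = abs(dot_tf32x3(a, b) - exact)
print(err_tf32, err_tf32x3)
```

The three-product form trades roughly 3x the tensor-core work for an error that is orders of magnitude smaller than a single TF32 product.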

Traits: AlwaysSpeculatableImplTrait, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
inputPrecision | ::mlir::triton::InputPrecisionAttr | allowed 32-bit signless integer cases: 0, 1, 2. Enum cases: tf32 (`TF32`), tf32x3 (`TF32x3`), ieee (`IEEE`)
maxNumImpreciseAcc | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand

Description

a

TensorOrMemDesc instance

b

TensorOrMemDesc instance

c

ranked tensor of floating-point or integer values

Results:

Result

Description

d

ranked tensor of floating-point or integer values

tt.elementwise_inline_asm (triton::ElementwiseInlineAsmOp)

Inline assembly applying an elementwise operation to a group of packed elements.

Syntax:

operation ::= `tt.elementwise_inline_asm` $asm_string attr-dict ($args^ `:` type($args))? `->` type($result)

Runs an inline asm block to generate one or more tensors.

The asm block is given packed_element elements at a time. Exactly which elements it receives is unspecified.

Traits: Elementwise, SameOperandsAndResultEncoding, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute | MLIR Type | Description
asm_string | ::mlir::StringAttr | string attribute
constraints | ::mlir::StringAttr | string attribute
pure | ::mlir::BoolAttr | bool attribute
packed_element | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand

Description

args

variadic of floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

Results:

Result

Description

result

variadic of floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

tt.expand_dims (triton::ExpandDimsOp)

Expand_dims

Syntax:

operation ::= `tt.expand_dims` $src attr-dict `:` type($src) `->` type($result)
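tt.expand_dims has no prose description here. Its result shape is the source shape with a new size-1 dimension inserted at position `axis`, analogous to `tl.expand_dims` in the Python frontend. The sketch below computes that shape in Python, assuming a non-negative `axis`; `expand_dims_shape` is a hypothetical helper.

```python
from typing import List, Sequence


def expand_dims_shape(shape: Sequence[int], axis: int) -> List[int]:
    """Result shape of inserting a size-1 dimension before position `axis`."""
    if not 0 <= axis <= len(shape):
        raise ValueError("axis out of range")
    return list(shape[:axis]) + [1] + list(shape[axis:])


print(expand_dims_shape([32], 0))  # [1, 32]
print(expand_dims_shape([32], 1))  # [32, 1]
```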

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand

Description

src

ranked tensor of floating-point or integer or ptr values

Results:

Result

Description

result

ranked tensor of floating-point or integer or ptr values

tt.experimental_descriptor_load (triton::ExperimentalDescriptorLoadOp)

Load from descriptor

Syntax:

operation ::= `tt.experimental_descriptor_load` $desc_ptr `[` $indices `]`
              oilist(
              `cacheModifier` `=` $cache |
              `evictionPolicy` `=` $evict
              )
              attr-dict `:` qualified(type($desc_ptr)) `->` type($result)

This operation will be lowered to an NVIDIA TMA load operation on targets that support it. desc_ptr is a pointer to the TMA descriptor allocated in global memory. The destination tensor type and shape must match the descriptor; otherwise, the result is undefined.

This is an escape hatch that exists only for testing and experimentation. This op will be removed in the future.

Traits: TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Read on ::mlir::triton::GlobalMemory}

Attributes:

Attribute | MLIR Type | Description
cache | ::mlir::triton::CacheModifierAttr | allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6. Enum cases: none (`NONE`), ca (`CA`), cg (`CG`), wb (`WB`), cs (`CS`), wt (`WT`)
evict | ::mlir::triton::EvictionPolicyAttr | allowed 32-bit signless integer cases: 1, 2, 3. Enum cases: evict_normal (`NORMAL`), evict_first (`EVICT_FIRST`), evict_last (`EVICT_LAST`)

Operands:

Operand

Description

desc_ptr

Pointer type (::mlir::triton::PointerType) in Triton IR type system

indices

variadic of 32-bit signless integer

Results:

Result

Description

result

ranked tensor of floating-point or integer or ptr values

tt.experimental_descriptor_store (triton::ExperimentalDescriptorStoreOp)

Store value based on descriptor

Syntax:

operation ::= `tt.experimental_descriptor_store` $desc_ptr `[` $indices `]` `,` $src
              attr-dict `:` qualified(type($desc_ptr)) `,` type($src)

This operation will be lowered to an NVIDIA TMA store operation on targets that support it. desc_ptr is a pointer to the TMA descriptor allocated in global memory. The shape and types of src must match the descriptor; otherwise, the result is undefined.

This is an escape hatch that exists only for testing and experimentation. This op will be removed in the future.

Traits: TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::triton::GlobalMemory}

Operands:

Operand

Description

desc_ptr

Pointer type (::mlir::triton::PointerType) in Triton IR type system

src

ranked tensor of floating-point or integer or ptr values

indices

variadic of 32-bit signless integer

tt.extern_elementwise (triton::ExternElementwiseOp)

Syntax:

operation ::= `tt.extern_elementwise` operands attr-dict `:` functional-type(operands, $result)

Calls the external function $symbol implemented in $libpath/$libname with $srcs and returns $libpath/$libname:$symbol($srcs…).

Traits: Elementwise, SameOperandsAndResultEncoding, SameVariadicOperandSize, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface

Attributes:

Attribute | MLIR Type | Description
libname | ::mlir::StringAttr | string attribute
libpath | ::mlir::StringAttr | string attribute
symbol | ::mlir::StringAttr | string attribute
pure | ::mlir::BoolAttr | bool attribute

Operands:

Operand

Description

srcs

variadic of floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

Results:

Result

Description

result

floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

tt.fp_to_fp (triton::FpToFpOp)

Floating point casting for custom types

Syntax:

operation ::= `tt.fp_to_fp` $src attr-dict  (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)

Floating-point casting for custom types (F8) and for non-default rounding modes.

F8 <-> FP16, BF16, FP32, FP64
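The two rounding modes, rtz (round toward zero) and rtne (round to nearest, ties to even), can be illustrated with fp32 to bf16 narrowing, because bf16 is simply the upper 16 bits of fp32 and both modes reduce to bit arithmetic. This sketch illustrates the rounding modes only; it makes no claim about which source/destination type pairs tt.fp_to_fp accepts, NaN inputs are not handled, and all function names are hypothetical.

```python
import struct


def f32_bits(x: float) -> int:
    """Raw 32-bit pattern of x rounded to fp32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_f32(b: int) -> float:
    """fp32 value with the given 32-bit pattern."""
    return struct.unpack("<f", struct.pack("<I", b))[0]


def to_bf16_bits(x: float, rounding: str) -> int:
    """Narrow fp32 -> bf16 bits. NaN inputs are not handled by this sketch."""
    b = f32_bits(x)
    if rounding == "rtz":
        return b >> 16  # drop the low 16 bits: truncates the magnitude
    if rounding == "rtne":
        lsb = (b >> 16) & 1  # bias ties toward an even result
        return ((b + 0x7FFF + lsb) >> 16) & 0xFFFF
    raise ValueError(f"unknown rounding mode: {rounding}")


def bf16_bits_to_f32(h: int) -> float:
    """Widen bf16 bits back to an fp32 value (exact)."""
    return bits_to_f32(h << 16)


x = 1.005859375  # fp32 bits 0x3F80C000
print(hex(to_bf16_bits(x, "rtz")), hex(to_bf16_bits(x, "rtne")))  # 0x3f80 0x3f81
```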

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultEncoding, SameOperandsAndResultShape, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
rounding | ::mlir::triton::RoundingModeAttr | allowed 32-bit signless integer cases: 0, 1. Enum cases: rtz (`RTZ`), rtne (`RTNE`)

Operands:

Operand

Description

src

ranked tensor of floating-point values

Results:

Result

Description

result

ranked tensor of floating-point values

tt.get_num_programs (triton::GetNumProgramsOp)

Syntax:

operation ::= `tt.get_num_programs` $axis attr-dict `:` type($result)

Traits: AlwaysSpeculatableImplTrait, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
axis | ::mlir::triton::ProgramIDDimAttr | allowed 32-bit signless integer cases: 0, 1, 2. Enum cases: x (`X`), y (`Y`), z (`Z`)

Results:

Result

Description

result

32-bit signless integer

tt.get_program_id (triton::GetProgramIdOp)

Syntax:

operation ::= `tt.get_program_id` $axis attr-dict `:` type($result)
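tt.get_num_programs and tt.get_program_id have no prose description here. Informally, a kernel is launched over a grid of up to three axes (x, y, z); each program instance reads its own index along an axis with tt.get_program_id and the grid extent along that axis with tt.get_num_programs. The hypothetical `launch` helper below models this by running a kernel callback once per grid point, sequentially; real program instances run concurrently and in no guaranteed order.

```python
from itertools import product
from typing import Callable, List, Tuple

Grid = Tuple[int, int, int]


def launch(grid: Grid, kernel: Callable[[Grid, Grid], int]) -> List[int]:
    """Run `kernel(program_id, num_programs)` once per point of a 3D grid.

    program_id[axis] plays the role of tt.get_program_id axis and
    num_programs[axis] that of tt.get_num_programs axis (0, 1, 2 = x, y, z).
    Instances run one after another here; on a GPU they run concurrently.
    """
    gx, gy, gz = grid
    results = []
    for z, y, x in product(range(gz), range(gy), range(gx)):
        results.append(kernel((x, y, z), grid))
    return results


# Flatten a 2x3 grid into linear program indices 0..5.
linear = launch((2, 3, 1), lambda pid, num: pid[0] + num[0] * pid[1])
print(sorted(linear))  # [0, 1, 2, 3, 4, 5]
```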

Traits: AlwaysSpeculatableImplTrait, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
axis | ::mlir::triton::ProgramIDDimAttr | allowed 32-bit signless integer cases: 0, 1, 2. Enum cases: x (`X`), y (`Y`), z (`Z`)

Results:

Result

Description

result

32-bit signless integer

tt.histogram (triton::HistogramOp)

Return a histogram of the inputs.

Syntax:

operation ::= `tt.histogram` $src attr-dict `:` type($src) `->` type($result)

Return the histogram of the input tensor. The number of bins is equal to the size of the output tensor. Each bin has a width of 1, and bins start at 0.
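In Python terms, bin i counts how many input elements equal i, for i from 0 to the output size minus 1. The description does not say what happens to values outside that range; this sketch simply ignores them, which is an assumption. `histogram` is a hypothetical helper name.

```python
from typing import List, Sequence


def histogram(src: Sequence[int], num_bins: int) -> List[int]:
    """Count occurrences of 0..num_bins-1; num_bins is the result tensor's size."""
    bins = [0] * num_bins
    for v in src:
        if 0 <= v < num_bins:  # assumption: out-of-range values are ignored here
            bins[v] += 1
    return bins


print(histogram([0, 1, 1, 3, 3, 3], num_bins=4))  # [1, 2, 0, 3]
```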

Traits: AlwaysSpeculatableImplTrait, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

src

ranked tensor of integer values

Results:

Result

Description

result

ranked tensor of integer values

tt.int_to_ptr (triton::IntToPtrOp)

Cast int64 to pointer

Syntax:

operation ::= `tt.int_to_ptr` $src attr-dict `:` type($src) `->` type($result)

Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultEncoding, SameOperandsAndResultShape, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

src

64-bit signless integer or tensor of 64-bit signless integer values

Results:

Result

Description

result

ptr or ranked tensor of ptr values

tt.join (triton::JoinOp)

Join two tensors along a new, minor dimension

Syntax:

operation ::= `tt.join` $lhs `,` $rhs attr-dict `:` type($lhs) `->` type($result)

For example, if the two input tensors are 4x8xf32, returns a tensor of shape 4x8x2xf32.

Because Triton tensors always have a power-of-two number of elements, the two input tensors must have the same shape.
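The pairing along the new minor dimension can be modelled with nested Python lists: at every index, the result holds `[lhs_element, rhs_element]`. `join` here is a hypothetical helper and a sketch of the shape and element placement only.

```python
from typing import Any


def join(lhs: Any, rhs: Any) -> Any:
    """Pair up elements of two same-shape nested lists along a new minor axis."""
    if isinstance(lhs, list) or isinstance(rhs, list):
        if not (isinstance(lhs, list) and isinstance(rhs, list)) or len(lhs) != len(rhs):
            raise ValueError("lhs and rhs must have the same shape")
        return [join(l, r) for l, r in zip(lhs, rhs)]
    return [lhs, rhs]


# 2x2 and 2x2 -> 2x2x2
print(join([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[[1, 5], [2, 6]], [[3, 7], [4, 8]]]
```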

Traits: SameTypeOperands, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

lhs

ranked tensor of floating-point or integer or ptr values

rhs

ranked tensor of floating-point or integer or ptr values

Results:

Result

Description

result

ranked tensor of floating-point or integer or ptr values

tt.load (triton::LoadOp)

Load from a tensor of pointers or from a tensor pointer

Syntax:

operation ::= `tt.load` $ptr (`,` $mask^)? (`,` $other^)?
              oilist(
              `cacheModifier` `=` $cache |
              `evictionPolicy` `=` $evict
              )
              attr-dict `:` type($ptr)
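tt.load has no prose description here. For a tensor of pointers, lanes whose mask is true read memory and lanes whose mask is false take the corresponding element of `other`; we assume the value is unspecified when `other` is absent. The sketch below models that masked gather in Python, with memory as a dictionary keyed by element address; `load` is a hypothetical name, and cache/eviction hints, volatility, and block pointers are not modelled.

```python
from typing import Dict, List, Optional, Sequence


def load(memory: Dict[int, float], ptrs: Sequence[int],
         mask: Optional[Sequence[bool]] = None,
         other: Optional[Sequence[float]] = None) -> List[Optional[float]]:
    """Masked gather: lanes with mask False return `other` instead of reading."""
    out: List[Optional[float]] = []
    for i, p in enumerate(ptrs):
        if mask is None or mask[i]:
            out.append(memory[p])
        elif other is not None:
            out.append(other[i])
        else:
            out.append(None)  # stands in for an unspecified value
    return out


# Tail of an array of length 2 loaded with a block of 4 lanes.
mem = {0: 1.5, 1: 2.5}
n = 2
offs = [0, 1, 2, 3]
print(load(mem, offs, mask=[o < n for o in offs], other=[0.0] * 4))  # [1.5, 2.5, 0.0, 0.0]
```

Masking is what makes the out-of-bounds lanes 2 and 3 safe: they never touch `memory`.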

Traits: AttrSizedOperandSegments, SameLoadStoreOperandsAndResultEncoding, SameLoadStoreOperandsAndResultShape, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: InferTypeOpInterface, MemoryEffectOpInterface

Attributes:

Attribute | MLIR Type | Description
boundaryCheck | ::mlir::DenseI32ArrayAttr | i32 dense array attribute
padding | ::mlir::triton::PaddingOptionAttr | allowed 32-bit signless integer cases: 1, 2. Enum cases: zero (`PAD_ZERO`), nan (`PAD_NAN`)
cache | ::mlir::triton::CacheModifierAttr | allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6. Enum cases: none (`NONE`), ca (`CA`), cg (`CG`), wb (`WB`), cs (`CS`), wt (`WT`)
evict | ::mlir::triton::EvictionPolicyAttr | allowed 32-bit signless integer cases: 1, 2, 3. Enum cases: evict_normal (`NORMAL`), evict_first (`EVICT_FIRST`), evict_last (`EVICT_LAST`)
isVolatile | ::mlir::BoolAttr | bool attribute

Operands:

Operand

Description

ptr

ptr or ranked tensor of ptr values or ptr

mask

1-bit signless integer or ranked tensor of 1-bit signless integer values

other

floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

Results:

Result

Description

result

floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

tt.make_range (triton::MakeRangeOp)

Make range

Syntax:

operation ::= `tt.make_range` attr-dict `:` type($result)

Returns a 1D int32 tensor.

Values span from $start to $end (exclusive), with a step of 1.
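Because $end is exclusive, the result has end - start elements. A minimal Python model, with `make_range` as a hypothetical name:

```python
from typing import List


def make_range(start: int, end: int) -> List[int]:
    """Values start, start+1, ..., end-1 (end is exclusive)."""
    if end <= start:
        raise ValueError("end must be greater than start")
    return list(range(start, end))


# Roughly: tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
print(make_range(0, 4))  # [0, 1, 2, 3]
```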

Traits: AlwaysSpeculatableImplTrait, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
start | ::mlir::IntegerAttr | 32-bit signless integer attribute
end | ::mlir::IntegerAttr | 32-bit signless integer attribute

Results:

Result

Description

result

ranked tensor of integer values

tt.make_tensor_ptr (triton::MakeTensorPtrOp)

Make a tensor pointer from meta information of the parent tensor and the specified block

Syntax:

operation ::= `tt.make_tensor_ptr` $base `,` `[` $shape `]` `,` `[` $strides `]` `,` `[` $offsets `]` attr-dict `:` type($result)

tt.make_tensor_ptr takes meta information for both the parent tensor and the block tensor and returns a pointer to the block tensor, e.g., a value of type !tt.ptr<tensor<8x8xf16>>.
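One way to picture a tensor pointer is as a record of the parent tensor's base, shape, and strides plus the block's origin (`offsets`); tt.advance then only shifts that origin. The Python sketch below follows that reading. The addressing rule address = base + (sum over d of (offsets[d] + idx[d]) * strides[d]) * element size is our interpretation of the operands, not taken from the implementation; the block shape (from the result type) and the `order` attribute are not modelled, and `BlockPtr`, `advance`, and `element_address` are hypothetical names.

```python
from dataclasses import dataclass, replace
from typing import List, Optional, Sequence


@dataclass(frozen=True)
class BlockPtr:
    """Hypothetical model of the value produced by tt.make_tensor_ptr."""
    base: int            # byte address of the parent tensor's first element
    shape: List[int]     # parent tensor shape (the i64 `shape` operands)
    strides: List[int]   # parent strides in elements (the i64 `strides` operands)
    offsets: List[int]   # block origin in the parent (the i32 `offsets` operands)
    elem_size: int       # pointee size in bytes, e.g. 2 for f16


def advance(bp: BlockPtr, deltas: Sequence[int]) -> BlockPtr:
    """Model of tt.advance: shift the block origin, keep everything else."""
    return replace(bp, offsets=[o + d for o, d in zip(bp.offsets, deltas)])


def element_address(bp: BlockPtr, idx: Sequence[int]) -> Optional[int]:
    """Byte address of block element `idx`, or None if outside the parent tensor."""
    coords = [o + i for o, i in zip(bp.offsets, idx)]
    if any(c < 0 or c >= s for c, s in zip(coords, bp.shape)):
        return None
    return bp.base + sum(c * st for c, st in zip(coords, bp.strides)) * bp.elem_size


# A 16x16 f16 row-major parent tensor, with a block starting at the origin.
bp = BlockPtr(base=0, shape=[16, 16], strides=[16, 1], offsets=[0, 0], elem_size=2)
print(element_address(bp, (1, 2)))                   # 36
print(element_address(advance(bp, [8, 8]), (0, 0)))  # 272
```

Out-of-bounds elements (reported as `None` here) are what the boundaryCheck and padding attributes of tt.load are concerned with.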

Traits: AlwaysSpeculatableImplTrait, SameVariadicOperandSize, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
order | ::mlir::DenseI32ArrayAttr | i32 dense array attribute

Operands:

Operand

Description

base

ptr

shape

variadic of 64-bit signless integer

strides

variadic of 64-bit signless integer

offsets

variadic of 32-bit signless integer

Results:

Result

Description

result

ptr

tt.mulhiui (triton::MulhiUIOp)

Most significant N bits of the 2N-bit product of two integers

Syntax:

operation ::= `tt.mulhiui` $x `,` $y attr-dict `:` type($x)

Most significant N bits of the 2N-bit product of two integers.
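As the "ui" suffix suggests, both operands are treated as unsigned N-bit integers. A Python model, with `mulhiui` and its `bits` parameter as hypothetical names:

```python
def mulhiui(x: int, y: int, bits: int) -> int:
    """High `bits` bits of the 2*bits-bit unsigned product of x and y."""
    mask = (1 << bits) - 1  # reinterpret inputs as unsigned `bits`-bit values
    return ((x & mask) * (y & mask)) >> bits


print(hex(mulhiui(0xFFFFFFFF, 0xFFFFFFFF, 32)))  # 0xfffffffe
```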

Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultType, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

x

integer or ranked tensor of integer values

y

integer or ranked tensor of integer values

Results:

Result

Description

result

integer or ranked tensor of integer values

tt.precise_divf (triton::PreciseDivFOp)

Precise div for floating point types

Syntax:

operation ::= `tt.precise_divf` $x `,` $y attr-dict `:` type($x)

Precise div for floating point types.

Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultType, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

x

floating-point or ranked tensor of floating-point values

y

floating-point or ranked tensor of floating-point values

Results:

Result

Description

result

floating-point or ranked tensor of floating-point values

tt.precise_sqrt (triton::PreciseSqrtOp)

Precise sqrt for floating point types

Syntax:

operation ::= `tt.precise_sqrt` $x attr-dict `:` type($x)

Precise sqrt for floating point types.

Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultType, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

x

floating-point or ranked tensor of floating-point values

Results:

Result

Description

result

floating-point or ranked tensor of floating-point values

tt.print (triton::PrintOp)

Device-side print, as in CUDA for debugging

Syntax:

operation ::= `tt.print` $prefix attr-dict (`:` $args^ `:` type($args))?

tt.print takes a literal string prefix and an arbitrary number of scalar or tensor arguments to be printed. Format strings are generated automatically from the arguments.

Traits: TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::triton::GlobalMemory}

Attributes:

Attribute | MLIR Type | Description
prefix | ::mlir::StringAttr | string attribute
hex | ::mlir::BoolAttr | bool attribute

Operands:

Operand

Description

args

variadic of floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

tt.ptr_to_int (triton::PtrToIntOp)

Cast pointer to int64

Syntax:

operation ::= `tt.ptr_to_int` $src attr-dict `:` type($src) `->` type($result)

Traits: AlwaysSpeculatableImplTrait, Elementwise, SameOperandsAndResultEncoding, SameOperandsAndResultShape, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

src

ptr or ranked tensor of ptr values

Results:

Result

Description

result

64-bit signless integer or tensor of 64-bit signless integer values

tt.reduce (triton::ReduceOp)

Reduction using generic combination algorithm
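tt.reduce reduces its operands along `axis` with the combine region in its body, which is terminated by tt.reduce.return. The Python sketch below models a reduction of 1D operands; the argmax case shows why the op is variadic, since a value tensor and an index tensor are reduced together. The fold here is left to right, while the compiler may combine in another order, so the combine function should be associative. `tt_reduce` and `argmax_combine` are hypothetical names.

```python
from functools import reduce
from typing import Callable, Sequence, Tuple

Combine = Callable[[Tuple, Tuple], Tuple]


def tt_reduce(srcs: Sequence[Sequence], combine: Combine) -> Tuple:
    """Reduce equally sized 1D operands (axis = 0) with a combine function.

    Each element is the tuple of the operands' values at one index, loosely
    mirroring the combine region. The fold is left to right here.
    """
    elems = list(zip(*srcs))
    if not elems:
        raise ValueError("cannot reduce an empty tensor")
    return reduce(combine, elems)


def argmax_combine(a: Tuple, b: Tuple) -> Tuple:
    """Keep the larger value; on ties keep the smaller index."""
    (va, ia), (vb, ib) = a, b
    if va > vb or (va == vb and ia < ib):
        return a
    return b


values = [3, 7, 2, 7]
print(tt_reduce([values, list(range(len(values)))], argmax_combine))  # (7, 1)
```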

Traits: AlwaysSpeculatableImplTrait, SameOperandsEncoding, SingleBlock, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

Attribute | MLIR Type | Description
axis | ::mlir::IntegerAttr | 32-bit signless integer attribute

Operands:

Operand

Description

srcs

variadic of ranked tensor of floating-point or integer or ptr values

Results:

Result

Description

result

variadic of floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr

tt.reduce.return (triton::ReduceReturnOp)

Terminator for reduce operator

Syntax:

operation ::= `tt.reduce.return` $result attr-dict `:` type($result)

Traits: AlwaysSpeculatableImplTrait, HasParent<ReduceOp>, ReturnLike, TensorSizeTrait, Terminator, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface

Effects: MemoryEffects::Effect{}

Operands:

Operand

Description

result

variadic of any type

tt.reshape (triton::ReshapeOp)

Reinterpret a tensor to a different shape. It may change the order of elements if the allow_reorder attribute is set.

Syntax:

operation ::= `tt.reshape` $src attr-dict `:` type($src) `->` type($result)

Reinterprets a tensor as a different shape with the same number of elements.

If `allow_reorder` is set, the compiler is free to change the order of elements to generate more efficient code.

If `efficient_layout` is set, this is a hint that the destination layout should be kept for performance reasons. The compiler is still free to change it if that yields better performance.
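When `allow_reorder` is not set, a reshape can be modeled as reading the elements in row-major order and regrouping them into the new shape, which only makes sense when both shapes hold the same number of elements. The pure-Python sketch below assumes that row-major model; `flatten`, `build`, and `reshape` are hypothetical helpers operating on nested lists. With `allow_reorder` set, the compiler is not bound to this element order, so the sketch does not describe that case.

```python
from math import prod
from typing import Any, List, Sequence


def flatten(x: Any) -> List[Any]:
    """Row-major flattening of a nested-list tensor."""
    if not isinstance(x, list):
        return [x]
    out: List[Any] = []
    for item in x:
        out.extend(flatten(item))
    return out


def build(flat: Sequence[Any], shape: Sequence[int]) -> Any:
    """Rebuild a nested-list tensor of `shape` from row-major elements."""
    if not shape:
        return flat[0]
    step = prod(shape[1:])  # number of elements per slice of the leading dimension
    return [build(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def reshape(x: Any, shape: Sequence[int]) -> Any:
    """Model of tt.reshape without allow_reorder: same elements, same row-major order."""
    flat = flatten(x)
    if len(flat) != prod(shape):
        raise ValueError("reshape must preserve the number of elements")
    return build(flat, list(shape))


print(reshape([[1, 2, 3], [4, 5, 6]], [3, 2]))  # [[1, 2], [3, 4], [5, 6]]
```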

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| `allow_reorder` | `::mlir::BoolAttr` | bool attribute |
| `efficient_layout` | `::mlir::UnitAttr` | unit attribute |

Operands:

| Operand | Description |
| ------- | ----------- |
| `src` | ranked tensor of floating-point or integer or ptr values |

Results:

| Result | Description |
| ------ | ----------- |
| `result` | ranked tensor of floating-point or integer or ptr values |

tt.scan (triton::ScanOp)

Associative scan using generic combination algorithm
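The effect of a scan can be modeled sequentially as below. This sketch rests on assumptions not stated in this reference: a single 1D operand, an inclusive scan (each output includes the element at its own position), and `reverse` taken to mean scanning from the last element toward the first. The name `scan_1d` is hypothetical, and a GPU lowering would evaluate the associative combine function in parallel rather than left to right.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def scan_1d(x: List[T], combine: Callable[[T, T], T], reverse: bool = False) -> List[T]:
    """Sequential model of an inclusive scan over a 1D tensor.

    `combine` stands in for the region body that ends in tt.scan.return.
    The result has the same shape as the input.
    """
    out: List[T] = list(x)
    indices = list(range(len(x)))
    if reverse:
        indices.reverse()
    # Each position combines the running value of its predecessor (in scan order)
    # with its own element. For reverse scans, passing the accumulator first is an
    # assumption; it only matters for non-commutative combine functions.
    for prev, cur in zip(indices, indices[1:]):
        out[cur] = combine(out[prev], x[cur])
    return out


print(scan_1d([1, 2, 3, 4], lambda a, b: a + b))                # [1, 3, 6, 10]
print(scan_1d([1, 2, 3, 4], lambda a, b: a + b, reverse=True))  # [10, 9, 7, 4]
```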

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultEncoding, SameOperandsAndResultShape, SingleBlock, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| `axis` | `::mlir::IntegerAttr` | 32-bit signless integer attribute |
| `reverse` | `::mlir::BoolAttr` | bool attribute |

Operands:

| Operand | Description |
| ------- | ----------- |
| `srcs` | variadic of ranked tensor of floating-point or integer or ptr values |

Results:

| Result | Description |
| ------ | ----------- |
| `result` | variadic of ranked tensor of floating-point or integer or ptr values |

tt.scan.return (triton::ScanReturnOp)

Terminator for scan operator

Syntax:

operation ::= `tt.scan.return` $result attr-dict `:` type($result)

Traits: AlwaysSpeculatableImplTrait, HasParent<ScanOp>, ReturnLike, TensorSizeTrait, Terminator, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface), RegionBranchTerminatorOpInterface

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| ------- | ----------- |
| `result` | variadic of any type |

tt.splat (triton::SplatOp)

Splat

Syntax:

operation ::= `tt.splat` $src attr-dict `:` type($src) `->` type($result)

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, SameOperandsAndResultEncoding, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| ------- | ----------- |
| `src` | floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr |

Results:

| Result | Description |
| ------ | ----------- |
| `result` | ranked tensor of floating-point or integer or ptr values |

tt.split (triton::SplitOp)

Splits a tensor into two, along its last dimension

Syntax:

operation ::= `tt.split` $src attr-dict `:` type($src) `->` type($outLHS)

The input must be a tensor whose last dimension has size 2. Returns two tensors, src[…, 0] and src[…, 1].

For example, if the input shape is 4x8x2xf32, returns two tensors of shape 4x8xf32.
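The shape rule above can be sketched in pure Python on nested lists: the last dimension, which must have size 2, is peeled off to produce `src[..., 0]` and `src[..., 1]`. The helper name `split_last` is hypothetical, and the sketch models only the values, not the layouts the compiler assigns to the results.

```python
from typing import Any, Tuple


def split_last(src: Any) -> Tuple[Any, Any]:
    """Return (src[..., 0], src[..., 1]) for a nested-list tensor whose last dim is 2."""
    # Base case: we are at the innermost dimension, which must have size 2.
    if not isinstance(src[0], list):
        if len(src) != 2:
            raise ValueError("the last dimension must have size 2")
        return src[0], src[1]
    # Recurse into each slice of the leading dimension and regroup the two halves.
    pairs = [split_last(s) for s in src]
    return [p[0] for p in pairs], [p[1] for p in pairs]


x = [[[1, 2], [3, 4]],
     [[5, 6], [7, 8]]]  # shape 2x2x2
lhs, rhs = split_last(x)
print(lhs)  # [[1, 3], [5, 7]]
print(rhs)  # [[2, 4], [6, 8]]
```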

Traits: TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Operands:

| Operand | Description |
| ------- | ----------- |
| `src` | ranked tensor of floating-point or integer or ptr values |

Results:

| Result | Description |
| ------ | ----------- |
| `outLHS` | ranked tensor of floating-point or integer or ptr values |
| `outRHS` | ranked tensor of floating-point or integer or ptr values |

tt.store (triton::StoreOp)

Store by a tensor of pointers or by a tensor pointer

Syntax:

operation ::= `tt.store` $ptr `,` $value (`,` $mask^)?
              oilist(`cacheModifier` `=` $cache | `evictionPolicy` `=` $evict)
              attr-dict `:` type($ptr)

Traits: SameLoadStoreOperandsEncoding, SameLoadStoreOperandsShape, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: MemoryEffectOpInterface (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{MemoryEffects::Write on ::mlir::triton::GlobalMemory}

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| `boundaryCheck` | `::mlir::DenseI32ArrayAttr` | i32 dense array attribute |
| `cache` | `::mlir::triton::CacheModifierAttr` | allowed 32-bit signless integer cases: 1, 2, 3, 4, 5, 6. Enum cases: none (`NONE`), ca (`CA`), cg (`CG`), wb (`WB`), cs (`CS`), wt (`WT`) |
| `evict` | `::mlir::triton::EvictionPolicyAttr` | allowed 32-bit signless integer cases: 1, 2, 3. Enum cases: evict_normal (`NORMAL`), evict_first (`EVICT_FIRST`), evict_last (`EVICT_LAST`) |

Operands:

| Operand | Description |
| ------- | ----------- |
| `ptr` | ptr or ranked tensor of ptr values or ptr |
| `value` | floating-point or ranked tensor of floating-point values or integer or ranked tensor of integer values or ptr or ranked tensor of ptr values or ptr |
| `mask` | 1-bit signless integer or ranked tensor of 1-bit signless integer values |

tt.trans (triton::TransOp)

Rearrange the dimensions of a tensor

Syntax:

operation ::= `tt.trans` $src attr-dict `:` type($src) `->` type($result)

For example, given a tensor x with shape [1,2,4], transpose(x) with order=[2,0,1] rearranges the tensor to have shape [4,1,2].
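The `order` semantics in this example can be checked with a small sketch: result dimension `i` is source dimension `order[i]`, so the element at result index `idx` comes from the source index `j` with `j[order[i]] == idx[i]`. This is the same convention as NumPy's `transpose(x, axes)`, and it reproduces the `[1,2,4]` to `[4,1,2]` example above. The helpers `trans_shape` and `trans` are hypothetical, and the tensor is modeled as a dict from index tuples to values for brevity.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple

Index = Tuple[int, ...]


def trans_shape(shape: Sequence[int], order: Sequence[int]) -> List[int]:
    """Result dimension i is source dimension order[i]."""
    if sorted(order) != list(range(len(shape))):
        raise ValueError("order must be a permutation of the source dimensions")
    return [shape[d] for d in order]


def trans(x: Dict[Index, int], shape: Sequence[int], order: Sequence[int]) -> Dict[Index, int]:
    """Permute a dict-of-indices tensor: result[idx] = x[j] with j[order[i]] == idx[i]."""
    out: Dict[Index, int] = {}
    for idx in product(*(range(n) for n in trans_shape(shape, order))):
        j = [0] * len(shape)
        for i, d in enumerate(order):
            j[d] = idx[i]
        out[idx] = x[tuple(j)]
    return out


shape = [1, 2, 4]
x = {idx: n for n, idx in enumerate(product(*(range(s) for s in shape)))}
print(trans_shape(shape, [2, 0, 1]))  # [4, 1, 2]
y = trans(x, shape, [2, 0, 1])
print(y[(3, 0, 1)] == x[(0, 1, 3)])   # True
```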

Although this op is called “trans”, it implements both tl.trans() and tl.permute(). (“permute” might be a better name, but it’s called “trans” because originally it only supported 2D tensors.)

Implementation note on encodings:

In the TritonGPU dialect (and probably others), an encoding is chosen for this op’s output so that the op is a no-op from the perspective of code generation.

For example, suppose tensor x has an encoding such that GPU thread [i,j,k] has a register containing element [i,j,k] of the tensor. Now we transpose x with order [2,1,0], i.e. we reverse the order of its dimensions. In TritonGPU, we will choose a layout for the output of the transpose so that GPU thread [i,j,k] has element [k,j,i] of transpose(x). But this is the same element it had before! All we’ve done is “rename” the element that thread [i,j,k] has.
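The “renaming” argument can be checked mechanically. In the sketch below, a layout is modeled, as a deliberate simplification far cruder than real TritonGPU encodings, as a map from thread coordinates to the tensor index that thread holds. Choosing the output layout so that thread `[i,j,k]` holds index `[k,j,i]` of `transpose(x)` means each thread's output element is exactly the element it already held, so no data moves. All names here are hypothetical.

```python
from itertools import product
from typing import Dict, Sequence, Tuple

Index = Tuple[int, ...]


def permute_index(src_idx: Index, order: Sequence[int]) -> Index:
    """Index of the same element in transpose(x, order): result dim i is source dim order[i]."""
    return tuple(src_idx[d] for d in order)


shape = (2, 3, 4)
order = (2, 1, 0)  # reverse the dimensions

# Input layout: thread [i,j,k] holds element [i,j,k] of x.
in_layout: Dict[Index, Index] = {t: t for t in product(*(range(n) for n in shape))}

# Output layout chosen as in the text: thread [i,j,k] holds element [k,j,i] of transpose(x).
out_layout: Dict[Index, Index] = {t: (t[2], t[1], t[0]) for t in in_layout}

# For every thread, the transposed index of its input element equals the index it
# holds in the output, i.e. it holds the same value before and after the op.
print(all(out_layout[t] == permute_index(in_layout[t], order) for t in in_layout))  # True
```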

The “real” transpose, i.e. moving data between GPU threads, occurs in convertLayout ops that appear before and/or after the operation.

We do this so that you can chain multiple data-movement ops (e.g. transpose+reshape+concat) without going to shared memory after each one.

Traits: AlwaysSpeculatableImplTrait, SameOperandsAndResultElementType, TensorSizeTrait, VerifyTensorLayoutsTrait

Interfaces: ConditionallySpeculatable, InferTypeOpInterface, NoMemoryEffect (MemoryEffectOpInterface)

Effects: MemoryEffects::Effect{}

Attributes:

| Attribute | MLIR Type | Description |
| --------- | --------- | ----------- |
| `order` | `::mlir::DenseI32ArrayAttr` | i32 dense array attribute |

Operands:

| Operand | Description |
| ------- | ----------- |
| `src` | TensorOrMemDesc instance |

Results:

| Result | Description |
| ------ | ----------- |
| `result` | TensorOrMemDesc instance |