triton.language

Programming Model

program_id

Returns the id of the current program instance along the given axis.

num_programs

Returns the number of program instances launched along the given axis.
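Here is a plain-Python model (not Triton code) of how program_id and num_programs are often combined in a persistent-style kernel. Every program instance runs the same body. Program pid handles tiles pid, pid + P, pid + 2P, and so on, where P is the number of launched programs. The helper name tiles_for_program and the tile counts are hypothetical.

```python
def tiles_for_program(pid: int, num_programs: int, num_tiles: int) -> list[int]:
    """Model of a grid-stride loop: program `pid` handles tiles pid, pid + P, pid + 2P, ..."""
    # In a kernel, pid would come from tl.program_id(axis=0)
    # and num_programs from tl.num_programs(axis=0).
    return list(range(pid, num_tiles, num_programs))


# Simulate a launch of 4 program instances over 10 tiles (hypothetical sizes).
NUM_PROGRAMS = 4
NUM_TILES = 10
assignment = {pid: tiles_for_program(pid, NUM_PROGRAMS, NUM_TILES) for pid in range(NUM_PROGRAMS)}

# Every tile is covered exactly once across all program instances.
covered = sorted(t for tiles in assignment.values() for t in tiles)
print(assignment[0])  # [0, 4, 8]
```

Striding by the launch size keeps the work balanced when there are more tiles than program instances.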

Creation Ops

arange

Returns contiguous values within the left-closed and right-open interval [start, end).
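Because the interval is half-open, arange(0, BLOCK) yields exactly BLOCK values. Triton also requires the length of an arange to be a power of two, so a bounds mask is usually needed when the data size is not a multiple of the block. The plain-Python sketch below models the common `offs = pid * BLOCK + tl.arange(0, BLOCK)` pattern and its mask. The names block_offsets and n_elements are hypothetical.

```python
def arange(start: int, end: int) -> list[int]:
    """Plain-Python model of tl.arange: values in the half-open interval [start, end)."""
    return list(range(start, end))


def block_offsets(pid: int, block: int, n_elements: int) -> tuple[list[int], list[bool]]:
    """Model of `offs = pid * BLOCK + tl.arange(0, BLOCK)` plus its bounds mask."""
    offs = [pid * block + i for i in arange(0, block)]
    mask = [o < n_elements for o in offs]  # False for lanes past the end of the data
    return offs, mask


# 10 elements processed in blocks of 4: the third block has 2 masked-off lanes.
offs, mask = block_offsets(pid=2, block=4, n_elements=10)
print(offs)  # [8, 9, 10, 11]
print(mask)  # [True, True, False, False]
```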

cat

Concatenates the given blocks.

full

Returns a tensor filled with the scalar value for the given shape and dtype.

zeros

Returns a tensor filled with the scalar value 0 for the given shape and dtype.

Shape Manipulation Ops

broadcast

Tries to broadcast the two given blocks to a common compatible shape.

broadcast_to

Tries to broadcast the given tensor to a new shape.

expand_dims

Expands the shape of a tensor by inserting new length-1 dimensions.
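The shape rules behind broadcast, broadcast_to and expand_dims can be modeled in plain Python. This sketch assumes NumPy-style broadcasting: a lower-rank shape is padded with leading 1s, and then each length-1 dimension stretches to match the other shape. The helper names are hypothetical. The example shows the outer-product idiom `x[:, None] * y[None, :]`, which produces a 2-D tile.

```python
def expand_dims_shape(shape: tuple[int, ...], axis: int) -> tuple[int, ...]:
    """Shape after inserting a length-1 dimension at `axis` (negative axes count from the end)."""
    if axis < 0:
        axis += len(shape) + 1
    return shape[:axis] + (1,) + shape[axis:]


def broadcast_shape(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
    """NumPy-style rule (assumed): align trailing dims; a dim of 1 stretches to match the other."""
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + a  # pad the shorter shape with leading 1s
    b = (1,) * (rank - len(b)) + b
    out = []
    for i, (da, db) in enumerate(zip(a, b)):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"incompatible dimensions at index {i}: {da} vs {db}")
    return tuple(out)


rows = expand_dims_shape((16,), 1)   # (16, 1)
cols = expand_dims_shape((32,), 0)   # (1, 32)
print(broadcast_shape(rows, cols))   # (16, 32)
```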

ravel

Returns a contiguous flattened view of x.

reshape

Returns a tensor with the same number of elements as input but with the provided shape.

trans

Returns a transposed tensor.

view

Returns a tensor with the same elements as input but a different shape.

Linear Algebra Ops

dot

Returns the matrix product of two blocks.
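Matmul kernels usually split the K dimension into tiles and accumulate each tile's product. The plain-Python model below shows that this gives the same result as a single full matrix product. In a real kernel, each tile product would be a dot call on an (M, BLOCK_K) block and a (BLOCK_K, N) block. The helpers matmul and tiled_matmul are hypothetical.

```python
Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Reference product of an (M, K) and a (K, N) matrix."""
    m, k, n = len(a), len(b), len(b[0])
    assert all(len(row) == k for row in a), "inner dimensions must match"
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(n)] for i in range(m)]


def tiled_matmul(a: Matrix, b: Matrix, block_k: int) -> Matrix:
    """Accumulate per-K-tile products, mirroring `acc += tl.dot(a_tile, b_tile)`."""
    m, k, n = len(a), len(b), len(b[0])
    acc = [[0.0] * n for _ in range(m)]
    for k0 in range(0, k, block_k):
        a_tile = [row[k0:k0 + block_k] for row in a]   # (M, BLOCK_K)
        b_tile = b[k0:k0 + block_k]                    # (BLOCK_K, N)
        part = matmul(a_tile, b_tile)
        for i in range(m):
            for j in range(n):
                acc[i][j] += part[i][j]
    return acc


a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b = [[1, 0], [0, 1], [1, 1], [2, -1]]
print(tiled_matmul(a, b, block_k=2) == matmul(a, b))  # True
```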

Memory Ops

load

Returns a tensor of data whose values are loaded from memory at the locations defined by pointer.

store

Stores a tensor of data into the memory locations defined by pointer.
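The plain-Python model below shows the masked load/store contract. Where mask is false, load returns other instead of reading memory, and store leaves memory unchanged. In the sketch, the list named memory stands in for global memory and integer offsets stand in for pointers. The names masked_load and masked_store are hypothetical.

```python
from typing import Optional


def masked_load(memory: list[float], offsets: list[int], mask: list[bool],
                other: Optional[float] = None) -> list[Optional[float]]:
    """Model of tl.load(ptr + offsets, mask=mask, other=other).

    Masked-off lanes never read memory; they yield `other`. Here None stands in
    for the undefined value a real kernel would produce when `other` is omitted.
    """
    return [memory[o] if m else other for o, m in zip(offsets, mask)]


def masked_store(memory: list[float], offsets: list[int], values: list[float], mask: list[bool]) -> None:
    """Model of tl.store(ptr + offsets, values, mask=mask): masked-off lanes write nothing."""
    for o, v, m in zip(offsets, values, mask):
        if m:
            memory[o] = v


# Copy a 6-element buffer with blocks of 4 lanes; the second block is partly masked.
src = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
dst = [0.0] * len(src)
for pid in range(2):
    offs = [pid * 4 + i for i in range(4)]
    mask = [o < len(src) for o in offs]
    masked_store(dst, offs, masked_load(src, offs, mask, other=0.0), mask)
print(dst)  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```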

Indexing Ops

where

Returns a tensor of elements from either x or y, depending on condition.

Math Ops

abs

Computes the element-wise absolute value of x.

exp

Computes the element-wise exponential of x.

log

Computes the element-wise natural logarithm of x.

fdiv

Returns the floating-point result of dividing x by y.

cos

Computes the element-wise cosine of x.

sin

Computes the element-wise sine of x.

sqrt

Computes the element-wise square root of x.

sigmoid

Computes the element-wise sigmoid of x.

softmax

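Computes the softmax of x. Unlike the other math ops, it is not purely element-wise: each output is the exponential of one element divided by a sum of exponentials reduced over x.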
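This plain-Python sketch computes a numerically stable softmax over a 1-D list, using only the math module. It follows the technique common in Triton kernels: subtract the maximum before exponentiating so that exp cannot overflow, then divide by the sum. It models the math only and is not the library's implementation.

```python
import math


def softmax(x: list[float]) -> list[float]:
    """Numerically stable softmax of a 1-D list."""
    m = max(x)                              # subtracting the max keeps every exp() <= 1
    num = [math.exp(v - m) for v in x]
    den = sum(num)
    return [v / den for v in num]


probs = softmax([1000.0, 1000.0, 1000.0])   # a naive exp(1000.0) would overflow
print(probs)  # each entry is 1/3
```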

umulhi

Returns the most significant 32 bits of the product of x and y.
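The plain-Python model below assumes x and y are treated as unsigned 32-bit integers. Combining umulhi with an ordinary 32-bit multiply, which keeps the low 32 bits, recovers the full 64-bit product. Hash- and counter-based generators rely on this kind of operation.

```python
def umulhi(x: int, y: int) -> int:
    """Upper 32 bits of the 64-bit product of two unsigned 32-bit integers."""
    x &= 0xFFFFFFFF  # reinterpret as uint32
    y &= 0xFFFFFFFF
    return (x * y) >> 32


print(umulhi(0xFFFFFFFF, 0xFFFFFFFF))  # 4294967294, i.e. 0xFFFFFFFE
```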

Reduction Ops

argmax

Returns the index of the maximum element in the input tensor along the provided axis.

argmin

Returns the index of the minimum element in the input tensor along the provided axis.
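This plain-Python model shows argmax and argmin on a 2-D list. Axis 0 scans down each column, and axis 1 scans along each row. Ties resolve to the first index in this sketch because Python's max and min return the first extreme item they encounter. The tie-breaking rule of your Triton version is worth checking separately. The helper _arg_reduce is hypothetical.

```python
def _arg_reduce(t, axis, pick):
    """Index of the extreme element along `axis` of a 2-D list; `pick` is max or min."""
    lines = t if axis == 1 else [list(col) for col in zip(*t)]  # transpose for axis 0
    return [pick(range(len(line)), key=line.__getitem__) for line in lines]


def argmax(t, axis):
    return _arg_reduce(t, axis, max)


def argmin(t, axis):
    return _arg_reduce(t, axis, min)


t = [[3, 9, 1],
     [7, 2, 8]]
print(argmax(t, axis=1))  # [1, 2]
print(argmin(t, axis=0))  # [0, 1, 0]
```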

max

Returns the maximum of all elements in the input tensor along the provided axis.

min

Returns the minimum of all elements in the input tensor along the provided axis.

reduce

Applies combine_fn to all elements of the input tensors along the provided axis.

sum

Returns the sum of all elements in the input tensor along the provided axis.

xor_sum

Returns the bitwise XOR of all elements in the input tensor along the provided axis.
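The reductions in this section can all be read as a fold with a combine function. The plain-Python sketch below models a 1-D reduce and recovers sum, xor_sum and max from it. A parallel GPU reduction combines elements in a tree rather than strictly left to right, so combine_fn should be associative. That is a general property of parallel reductions and is assumed here rather than taken from this page.

```python
import functools
import operator


def reduce(values: list, combine_fn):
    """Model of a 1-D reduction: fold combine_fn over the values."""
    return functools.reduce(combine_fn, values)


data = [3, 5, 6, 1]
print(reduce(data, operator.add))  # 15, what sum computes
print(reduce(data, operator.xor))  # 1, what xor_sum computes
print(reduce(data, max))           # 6, what max computes
```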

Scan Ops

associative_scan

Applies combine_fn to each element of the input tensors along the provided axis, carrying the running result forward and updating it at every step.

cumsum

Returns the cumulative sum of the elements in the input tensor along the provided axis.

cumprod

Returns the cumulative product of the elements in the input tensor along the provided axis.
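This is a plain-Python model of a scan along one axis. It assumes the scan is inclusive, meaning output[i] combines elements 0 through i; itertools.accumulate computes exactly that. With addition the model yields cumsum, and with multiplication it yields cumprod.

```python
import itertools
import operator


def associative_scan(values, combine_fn):
    """Inclusive scan: output[i] is combine_fn folded over values[0..i], carried left to right."""
    return list(itertools.accumulate(values, combine_fn))


data = [1, 2, 3, 4]
print(associative_scan(data, operator.add))  # [1, 3, 6, 10], what cumsum computes
print(associative_scan(data, operator.mul))  # [1, 2, 6, 24], what cumprod computes
```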

Atomic Ops

atomic_add

Performs an atomic add at the memory location specified by pointer.

atomic_cas

Performs an atomic compare-and-swap at the memory location specified by pointer.

atomic_max

Performs an atomic max at the memory location specified by pointer.

atomic_min

Performs an atomic min at the memory location specified by pointer.

atomic_xchg

Performs an atomic exchange at the memory location specified by pointer.
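The sketch below models atomic read-modify-write semantics in plain Python. A single threading.Lock makes each operation indivisible, and each method returns the value stored before the update, as Triton's atomics do. The example shows that concurrent adds lose no increments. It also shows atomic_cas being used to try to acquire a spin lock held in a memory word. The class name AtomicBuffer is hypothetical.

```python
import threading


class AtomicBuffer:
    """Plain-Python model of atomic ops on a memory buffer; methods return the old value."""

    def __init__(self, size: int):
        self.data = [0] * size
        self._lock = threading.Lock()

    def atomic_add(self, idx: int, val: int) -> int:
        with self._lock:
            old = self.data[idx]
            self.data[idx] = old + val
            return old

    def atomic_cas(self, idx: int, cmp: int, val: int) -> int:
        with self._lock:
            old = self.data[idx]
            if old == cmp:  # swap only if memory still holds the expected value
                self.data[idx] = val
            return old


buf = AtomicBuffer(2)
workers = [threading.Thread(target=lambda: [buf.atomic_add(0, 1) for _ in range(1000)])
           for _ in range(8)]
for w in workers:
    w.start()
for w in workers:
    w.join()
print(buf.data[0])  # 8000: no increments are lost

# Spin-lock acquire attempts: succeed only when the lock word is 0.
print(buf.atomic_cas(1, 0, 1))  # 0 -> acquired
print(buf.atomic_cas(1, 0, 1))  # 1 -> already held
```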

Comparison Ops

minimum

Computes the element-wise minimum of x and y.

maximum

Computes the element-wise maximum of x and y.

Random Number Generation

randint4x

Given a seed scalar and an offset block, returns four blocks of random int32.

randint

Given a seed scalar and an offset block, returns a single block of random int32.

rand

Given a seed scalar and an offset block, returns a block of random float32 in \(U(0, 1)\).

randn

Given a seed scalar and an offset block, returns a block of random float32 in \(\mathcal{N}(0, 1)\).
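Triton's generators are counter-based (Philox). The output is a pure function of the seed and the offset, so no generator state has to be kept between programs, and each program can draw independent numbers just by using distinct offsets. The plain-Python sketch below illustrates that contract. It uses SplitMix64 as a simple stand-in mixing function. It is not Philox and will not reproduce Triton's numbers. The way seed and offset are packed into one counter is also a hypothetical choice.

```python
def splitmix64(x: int) -> int:
    """SplitMix64 finalizer: a stateless 64-bit mixing function (stand-in for Philox)."""
    x = (x + 0x9E3779B97F4A7C15) & 0xFFFFFFFFFFFFFFFF
    x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & 0xFFFFFFFFFFFFFFFF
    x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & 0xFFFFFFFFFFFFFFFF
    return x ^ (x >> 31)


def rand(seed: int, offsets: list[int]) -> list[float]:
    """Counter-based uniform floats in [0, 1): each value depends only on (seed, offset)."""
    out = []
    for off in offsets:
        # Hypothetical packing of seed and offset into a single 64-bit counter.
        bits = splitmix64(((seed & 0xFFFFFFFF) << 32) | (off & 0xFFFFFFFF))
        out.append((bits >> 40) / float(1 << 24))  # top 24 bits, exactly representable in float32
    return out


a = rand(seed=42, offsets=list(range(4)))
b = rand(seed=42, offsets=list(range(4)))
print(a == b)  # True: the same seed and offsets reproduce the same block
```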

Compiler Hint Ops

debug_barrier

Insert a barrier to synchronize all threads in a block.

max_constancy

Lets the compiler know that the first value elements of input are constant.

max_contiguous

Lets the compiler know that the first value elements of input are contiguous.

multiple_of

Lets the compiler know that the values in input are all multiples of value.
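These hints are promises the compiler relies on, not runtime checks, so code generated from a false hint may be wrong. When debugging, it can help to check the same promises on the host. The plain-Python sketch below reads multiple_of and max_contiguous literally as described above. The helper names satisfies_multiple_of and satisfies_max_contiguous are hypothetical.

```python
def satisfies_multiple_of(values: list[int], divisor: int) -> bool:
    """True if every value is a multiple of `divisor`, the promise made by multiple_of."""
    return all(v % divisor == 0 for v in values)


def satisfies_max_contiguous(values: list[int], run: int) -> bool:
    """True if the first `run` values are consecutive integers, the promise made by max_contiguous."""
    head = values[:run]
    return all(b == a + 1 for a, b in zip(head, head[1:]))


print(satisfies_multiple_of([0, 16, 48], 16))                 # True
print(satisfies_max_contiguous([8, 9, 10, 11, 40, 41], 4))    # True
print(satisfies_max_contiguous([8, 9, 11, 12], 4))            # False
```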

Debug Ops

static_print

Print the values at compile time.

static_assert

Assert the condition at compile time.

device_print

Print the values at runtime from the device.

device_assert

Assert the condition at runtime from the device.

Iterators

static_range

Iterator over a range whose bounds are compile-time constants; loops over it are unrolled by the compiler.
