triton.language

Programming Model

tensor

Represents an N-dimensional array of values or pointers.

program_id

Returns the id of the current program instance along the given axis.

num_programs

Returns the number of program instances launched along the given axis.
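As a rough mental model of the launch grid, the pure-Python simulation below runs one loop iteration per program instance. The names add_kernel, launch, and BLOCK are hypothetical. Inside a real @triton.jit kernel, pid would come from program_id(axis=0), the offsets from arange, and the bounds check would be a mask passed to load and store. A GPU runs the instances in parallel and in no guaranteed order, which the sequential loop here does not capture.

```python
def cdiv(x, div):
    # Ceiling division: how many blocks of size `div` cover `x` elements.
    return (x + div - 1) // div


def add_kernel(pid, x, y, out, n, BLOCK):
    # Each program instance owns the half-open range [pid*BLOCK, (pid+1)*BLOCK).
    for off in range(pid * BLOCK, (pid + 1) * BLOCK):
        if off < n:  # mask: the last block may run past the end of the data
            out[off] = x[off] + y[off]


def launch(x, y, BLOCK):
    n = len(x)
    out = [0] * n
    num_programs = cdiv(n, BLOCK)  # grid size along axis 0
    for pid in range(num_programs):  # sequential here; parallel on a GPU
        add_kernel(pid, x, y, out, n, BLOCK)
    return out, num_programs
```

With 10 elements and BLOCK = 4, three program instances are launched and the last one masks off two of its lanes.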

Creation Ops

arange

Returns contiguous values within the half-open interval [start, end).

cat

Concatenates the given blocks.

full

Returns a tensor filled with the scalar value for the given shape and dtype.

zeros

Returns a tensor filled with the scalar value 0 for the given shape and dtype.

zeros_like

Creates a tensor of zeros with the same shape and type as a given tensor.

Shape Manipulation Ops

broadcast

Tries to broadcast the two given blocks to a common compatible shape.

broadcast_to

Tries to broadcast the given tensor to a new shape.
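The shape rule behind broadcast and broadcast_to can be sketched in pure Python. This sketch assumes NumPy-style semantics (shapes aligned from the right, missing leading dimensions treated as 1, size-1 dimensions stretched), which is how the common idiom of combining offs_m[:, None] with offs_n[None, :] into a 2D grid of offsets behaves. broadcast_shape is a hypothetical helper, not a triton.language function.

```python
def broadcast_shape(lhs, rhs):
    # Prepend length-1 dimensions so both shapes have the same rank.
    lhs, rhs = list(lhs), list(rhs)
    while len(lhs) < len(rhs):
        lhs.insert(0, 1)
    while len(rhs) < len(lhs):
        rhs.insert(0, 1)
    out = []
    for a, b in zip(lhs, rhs):
        if a == b or b == 1:
            out.append(a)
        elif a == 1:
            out.append(b)  # a size-1 dimension stretches to match
        else:
            raise ValueError(f"incompatible dimensions {a} and {b}")
    return tuple(out)
```

For example, an (M, 1) column of row offsets and a (1, N) row of column offsets broadcast to an (M, N) block.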

expand_dims

Expands the shape of a tensor by inserting new length-1 dimensions.

interleave

Interleaves the values of two tensors along their last dimension.

join

Joins the given tensors in a new, minor dimension.

permute

Permutes the dimensions of a tensor.

ravel

Returns a contiguous flattened view of x.

reshape

Returns a tensor with the same number of elements as input but with the provided shape.

split

Splits a tensor in two along its last dimension, which must have size 2.
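For one-dimensional inputs, the relationship between join, split, and interleave can be modeled with plain Python lists, where the inner lists stand in for the new minor dimension of size 2. These helpers are illustrative sketches of the documented shapes, not Triton's implementation: split undoes join, and interleave behaves like join followed by flattening the last two dimensions.

```python
def join(a, b):
    # Shapes (N,) and (N,) -> (N, 2): the new minor dimension holds the pair.
    assert len(a) == len(b)
    return [[x, y] for x, y in zip(a, b)]


def split(t):
    # Inverse of join: the last dimension must have size 2.
    assert all(len(row) == 2 for row in t)
    return [row[0] for row in t], [row[1] for row in t]


def interleave(a, b):
    # join, then flatten (N, 2) -> (2N,): a0, b0, a1, b1, ...
    return [v for pair in join(a, b) for v in pair]
```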

trans

Permutes the dimensions of a tensor. If no dimensions are given, it swaps the two dimensions of a 2D tensor, i.e. a transpose.

view

Returns a tensor with the same elements as input but a different shape. The order of the elements may not be preserved.

Linear Algebra Ops

dot

Returns the matrix product of two blocks.
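Kernels usually compute a large matrix product by accumulating dot over tiles of the shared K dimension. The pure-Python sketch below, with hypothetical helpers matmul and tiled_matmul, checks that summing the per-tile products gives the same result as one full product, which is why a loop of the form acc += dot(a_tile, b_tile) is valid.

```python
def matmul(a, b):
    # Plain product of an (M, K) block and a (K, N) block.
    M, K, N = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]


def tiled_matmul(a, b, BLOCK_K):
    # Accumulate partial products over slices of K, one tile at a time.
    M, K, N = len(a), len(b), len(b[0])
    acc = [[0] * N for _ in range(M)]
    for k0 in range(0, K, BLOCK_K):
        a_tile = [row[k0:k0 + BLOCK_K] for row in a]  # (M, <=BLOCK_K)
        b_tile = b[k0:k0 + BLOCK_K]                   # (<=BLOCK_K, N)
        part = matmul(a_tile, b_tile)
        for i in range(M):
            for j in range(N):
                acc[i][j] += part[i][j]
    return acc
```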

Memory/Pointer Ops

load

Returns a tensor of data whose values are loaded from memory at the locations defined by pointer.

store

Stores a tensor of data into the memory locations defined by pointer.
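The effect of a mask on load and store can be modeled on a flat Python list standing in for memory. In this semantic sketch, offsets play the role of the pointer block, masked-off lanes of a load return other, and masked-off lanes of a store leave memory untouched. Here other defaults to 0 purely for the sketch; in Triton it is safer to pass other explicitly whenever the values of masked-off lanes are used.

```python
def load(memory, offsets, mask=None, other=0):
    # Masked-off lanes are never read, so out-of-range offsets are harmless.
    if mask is None:
        mask = [True] * len(offsets)
    return [memory[o] if m else other for o, m in zip(offsets, mask)]


def store(memory, offsets, values, mask=None):
    # Masked-off lanes are simply not written.
    if mask is None:
        mask = [True] * len(offsets)
    for o, v, m in zip(offsets, values, mask):
        if m:
            memory[o] = v
```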

make_block_ptr

Returns a pointer to a block in a parent tensor.

advance

Advances a block pointer.
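A block pointer bundles a base address with the parent tensor's shape and strides, the current block offsets, and the block shape. The sketch below models a 2D row-major case with hypothetical Python helpers. It omits the order argument that make_block_ptr also takes, and it marks out-of-bounds positions with None to show what boundary checking on a load would have to guard. The helper advance returns a new block pointer rather than modifying its argument, matching the usual ptr = advance(ptr, ...) pattern.

```python
def make_block_ptr(base, shape, strides, offsets, block_shape):
    # Hypothetical model: a block pointer is only metadata until it is loaded.
    return {"base": base, "shape": shape, "strides": strides,
            "offsets": list(offsets), "block_shape": block_shape}


def advance(bp, deltas):
    # New block pointer whose offsets are shifted by `deltas`.
    new = dict(bp)
    new["offsets"] = [o + d for o, d in zip(bp["offsets"], deltas)]
    return new


def block_addresses(bp):
    # Element addresses covered by the block; None marks positions that
    # fall outside the parent tensor.
    (r0, c0), (R, C) = bp["offsets"], bp["block_shape"]
    (sr, sc), (H, W) = bp["strides"], bp["shape"]
    return [[bp["base"] + (r0 + i) * sr + (c0 + j) * sc
             if r0 + i < H and c0 + j < W else None
             for j in range(C)] for i in range(R)]
```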

Indexing Ops

flip

Flips a tensor x along the dimension dim.

where

Returns a tensor of elements from either x or y, depending on condition.

swizzle2d

Transforms the indices of a row-major size_i * size_j matrix into those of one where the indices are column-major within each group of size_g rows.
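The index arithmetic can be written out in plain Python. The function below is a sketch that follows the description above, with min standing in for the element-wise minimum. Applied to a 4x4 grid with size_g = 2, it renumbers the row-major layout so that indices run down each pair of rows first. A typical use is reordering program ids in a matrix-multiplication kernel so that neighbouring programs reuse the same input tiles from cache.

```python
def swizzle2d(i, j, size_i, size_j, size_g):
    ij = i * size_j + j                   # linear index in row-major order
    size_gj = size_g * size_j             # elements in one group of size_g rows
    group_id = ij // size_gj
    off_i = group_id * size_g             # first row of this group
    size_g = min(size_i - off_i, size_g)  # the last group may be shorter
    ij = ij % size_gj                     # position inside the group
    new_i = off_i + ij % size_g           # walk down the rows first
    new_j = ij // size_g
    return new_i, new_j


grid = [[None] * 4 for _ in range(4)]
for i in range(4):
    for j in range(4):
        ni, nj = swizzle2d(i, j, 4, 4, 2)
        grid[ni][nj] = i * 4 + j
# grid == [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]
```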

Math Ops

abs

Computes the element-wise absolute value of x.

cdiv

Computes the ceiling division of x by div.
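For positive integers, cdiv(x, div) equals ceil(x / div) and can be computed with integer arithmetic alone. The sketch below uses the common (x + div - 1) // div formulation; it illustrates the result rather than how Triton computes it. A typical use is sizing a launch grid so that every element is covered by some block.

```python
def cdiv(x, div):
    # ceil(x / div) for positive integers, without floating point.
    return (x + div - 1) // div


num_blocks = cdiv(1000, 128)  # 8 blocks of 128 cover 1000 elements
```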

clamp

Clamps the input tensor x within the range [min, max].

cos

Computes the element-wise cosine of x.

div_rn

Computes the element-wise precise division (rounding to nearest) of x and y.

erf

Computes the element-wise error function of x.

exp

Computes the element-wise exponential of x.

exp2

Computes the element-wise exponential (base 2) of x.

fdiv

Computes the element-wise fast division of x and y.

floor

Computes the element-wise floor of x.

fma

Computes the element-wise fused multiply-add of x, y, and z.

log

Computes the element-wise natural logarithm of x.

log2

Computes the element-wise logarithm (base 2) of x.

maximum

Computes the element-wise maximum of x and y.

minimum

Computes the element-wise minimum of x and y.

rsqrt

Computes the element-wise reciprocal square root (1 / sqrt(x)) of x.

sigmoid

Computes the element-wise sigmoid of x.

sin

Computes the element-wise sine of x.

softmax

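Computes the softmax of x. Each output element is the exponential of the corresponding input divided by the sum of the exponentials along the reduction axis, so unlike the other math ops it is not purely element-wise.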
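The numerically stable formulation commonly used for softmax subtracts the maximum before exponentiating. The pure-Python sketch below shows it for a single axis held in a list; it illustrates the math rather than Triton's code, and Triton's own version may differ in details such as which axis it reduces over or how it divides.

```python
import math


def softmax(xs):
    # Shift by the maximum so exp() cannot overflow; the shift cancels
    # between numerator and denominator, so the result is unchanged.
    m = max(xs)
    nums = [math.exp(x - m) for x in xs]
    den = sum(nums)  # reduction over the whole axis
    return [n / den for n in nums]
```

Without the shift, softmax([1000.0, 1000.0]) would overflow in exp; with it, the result is [0.5, 0.5].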

sqrt

Computes the element-wise fast square root of x.

sqrt_rn

Computes the element-wise precise square root (rounding to nearest) of x.

umulhi

Computes the element-wise most significant N bits of the 2N-bit product of x and y.
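Taking N = 32, the two halves of the 64-bit product can be illustrated with Python's arbitrary-precision integers. The helpers umulhi32 and umullo32 are hypothetical names for this sketch, which assumes unsigned 32-bit inputs: the high half is what umulhi describes, and the low half is what an ordinary 32-bit multiply keeps.

```python
MASK32 = 0xFFFFFFFF


def umulhi32(x, y):
    # High 32 bits of the full 64-bit product of two unsigned 32-bit values.
    return ((x & MASK32) * (y & MASK32)) >> 32


def umullo32(x, y):
    # Low 32 bits of the same product (ordinary wrapping multiply).
    return ((x & MASK32) * (y & MASK32)) & MASK32
```

Together the two halves reconstruct the exact product: (umulhi32(x, y) << 32) | umullo32(x, y) == x * y for any 32-bit unsigned x and y.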

Reduction Ops

argmax

Returns the index of the maximum element in the input tensor along the provided axis.

argmin

Returns the index of the minimum element in the input tensor along the provided axis.
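When several elements share the extreme value, which index is returned matters. Recent Triton releases expose a tie_break_left flag on argmax and argmin for this; the pure-Python sketch below assumes that True means "prefer the smallest index", and it models a single axis held in a list rather than Triton's implementation.

```python
def argmax(xs, tie_break_left=True):
    # Index of the largest element; on ties, keep the earliest index when
    # tie_break_left is True, otherwise move to the latest one.
    best = 0
    for i in range(1, len(xs)):
        if xs[i] > xs[best] or (not tie_break_left and xs[i] == xs[best]):
            best = i
    return best


def argmin(xs, tie_break_left=True):
    # Smallest element is the largest of the negated values (NaN-free input assumed).
    return argmax([-x for x in xs], tie_break_left)
```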

max

Returns the maximum of all elements in the input tensor along the provided axis.

min

Returns the minimum of all elements in the input tensor along the provided axis.

reduce

Applies combine_fn to all elements in the input tensors along the provided axis.

sum

Returns the sum of all elements in the input tensor along the provided axis.

xor_sum

Returns the XOR sum of all elements in the input tensor along the provided axis.
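A parallel reduction combines elements in a tree rather than strictly left to right, which is why the combine_fn passed to reduce should be associative. The pure-Python tree_reduce below is a hypothetical illustration of one such grouping; Triton does not promise this exact order. With addition or XOR (the operations behind sum and xor_sum) it agrees with a sequential fold, while a non-associative function such as subtraction does not.

```python
import functools
import operator


def tree_reduce(values, combine_fn):
    # Combine neighbours pairwise, halving the list each round; an odd
    # element out is carried into the next round unchanged.
    vals = list(values)
    while len(vals) > 1:
        nxt = [combine_fn(vals[k], vals[k + 1]) for k in range(0, len(vals) - 1, 2)]
        if len(vals) % 2:
            nxt.append(vals[-1])
        vals = nxt
    return vals[0]


xs = [1, 2, 3, 4, 5]
sum_tree = tree_reduce(xs, operator.add)                    # 15, same as a fold
xor_tree = tree_reduce(xs, operator.xor)                    # 1, same as a fold
diff_tree = tree_reduce([1, 2, 3, 4], operator.sub)         # 0
diff_fold = functools.reduce(operator.sub, [1, 2, 3, 4])    # -8
```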

Scan/Sort Ops

associative_scan

Applies combine_fn to each element of the input tensors along the provided axis, combining it with a running carry and updating the carry (a prefix scan).

cumprod

Returns the cumulative product of the elements in the input tensor along the provided axis.

cumsum

Returns the cumulative sum of the elements in the input tensor along the provided axis.
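associative_scan can be modeled as an inclusive prefix scan that threads a carry through the elements, and cumsum and cumprod are then the scans with addition and multiplication. The sketch below works on a single Python list (one axis); the helper names mirror the Triton functions but are plain Python stand-ins.

```python
import operator


def associative_scan(values, combine_fn):
    # Inclusive scan: out[0] = values[0], out[k] = combine_fn(out[k - 1], values[k]).
    out = []
    carry = None
    for k, v in enumerate(values):
        carry = v if k == 0 else combine_fn(carry, v)  # update the carry
        out.append(carry)
    return out


def cumsum(values):
    return associative_scan(values, operator.add)


def cumprod(values):
    return associative_scan(values, operator.mul)
```

Any associative combine function works, e.g. max gives a running maximum.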

histogram

Computes a histogram of the input tensor with num_bins bins. The bins have a width of 1 and start at 0.
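A pure-Python sketch of the binning rule: with width-1 bins starting at 0, an integer value v is counted in bin v. The sketch assumes integer inputs and simply skips values outside [0, num_bins); how Triton treats such values is not stated here, so out-of-range inputs should not be relied upon.

```python
def histogram(values, num_bins):
    # Bin k counts the inputs equal to k (width-1 bins starting at 0).
    counts = [0] * num_bins
    for v in values:
        if 0 <= v < num_bins:  # assumption: out-of-range values are ignored
            counts[v] += 1
    return counts
```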

sort

Sorts a tensor along a given dimension (the last dimension by default). This function can also be called as a member function on tensor, as x.sort(...) instead of sort(x, ...).

Atomic Ops

atomic_add

Performs an atomic add at the memory location specified by pointer.

atomic_and

Performs an atomic bitwise AND at the memory location specified by pointer.

atomic_cas

Performs an atomic compare-and-swap at the memory location specified by pointer.

atomic_max

Performs an atomic max at the memory location specified by pointer.

atomic_min

Performs an atomic min at the memory location specified by pointer.

atomic_or

Performs an atomic bitwise OR at the memory location specified by pointer.

atomic_xchg

Performs an atomic exchange at the memory location specified by pointer.

atomic_xor

Performs an atomic bitwise XOR at the memory location specified by pointer.
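The contract shared by these atomic ops, one indivisible read-modify-write per location that returns the previous value, can be modeled with a lock in plain Python. Memory, max_via_cas, and run_demo are hypothetical helpers. The retry loop shows how compare-and-swap can build other read-modify-write operations such as a maximum, although Triton provides atomic_max directly.

```python
import threading


class Memory:
    """Toy global memory. Each atomic method updates one location
    indivisibly and returns the value that was there before."""

    def __init__(self, size):
        self.data = [0] * size
        self._lock = threading.Lock()

    def atomic_add(self, ptr, val):
        with self._lock:
            old = self.data[ptr]
            self.data[ptr] = old + val
            return old

    def atomic_cas(self, ptr, cmp, val):
        # Store val only if the current value equals cmp; return the old value either way.
        with self._lock:
            old = self.data[ptr]
            if old == cmp:
                self.data[ptr] = val
            return old


def max_via_cas(mem, ptr, val):
    # Atomic maximum built from a compare-and-swap retry loop.
    old = mem.data[ptr]
    while old < val:
        seen = mem.atomic_cas(ptr, old, val)
        if seen == old:  # our swap won
            break
        old = seen  # another writer got there first; retry against its value


def run_demo(num_threads=4, adds_per_thread=1000):
    mem = Memory(2)

    def worker(k):
        for _ in range(adds_per_thread):
            mem.atomic_add(0, 1)          # shared counter
        for v in range(k, 100, num_threads):
            max_via_cas(mem, 1, v)        # shared running maximum

    threads = [threading.Thread(target=worker, args=(k,)) for k in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return mem.data[0], mem.data[1]
```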

Random Number Generation

randint4x

Given a seed scalar and an offset block, returns four blocks of random int32.

randint

Given a seed scalar and an offset block, returns a single block of random int32.

rand

Given a seed scalar and an offset block, returns a block of random float32 in \(U(0, 1)\).

randn

Given a seed scalar and an offset block, returns a block of random float32 in \(\mathcal{N}(0, 1)\).
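These generators are counter-based: the output is a pure function of seed and offset, so each element of an offset block can be produced independently and reproducibly. The sketch below implements a standard 10-round Philox4x32 generator (the family Triton's random functions are built on) in plain Python, plus a simple 24-bit uniform conversion and a Box-Muller transform for normal samples. The helper names mirror the Triton functions, but the key/counter layout and float conversions here are assumptions, so the outputs are not guaranteed to match Triton bit for bit.

```python
import math

MASK32 = 0xFFFFFFFF
# Standard Philox4x32 round multipliers and key increments.
ROUND_A, ROUND_B = 0xD2511F53, 0xCD9E8D57
KEY_A, KEY_B = 0x9E3779B9, 0xBB67AE85


def philox4x32(c0, c1, c2, c3, k0, k1, rounds=10):
    # The output depends only on (counter, key): no hidden state.
    for _ in range(rounds):
        prod_a = ROUND_A * c0
        prod_b = ROUND_B * c2
        c0, c1, c2, c3 = (
            (prod_b >> 32) ^ c1 ^ k0,
            prod_b & MASK32,
            (prod_a >> 32) ^ c3 ^ k1,
            prod_a & MASK32,
        )
        k0 = (k0 + KEY_A) & MASK32
        k1 = (k1 + KEY_B) & MASK32
    return c0, c1, c2, c3


def randint4x(seed, offset):
    # Assumed layout: key = two halves of the seed, counter = offset.
    return philox4x32(offset & MASK32, 0, 0, 0, seed & MASK32, (seed >> 32) & MASK32)


def rand(seed, offset):
    # Keep 24 random bits so the float is exactly representable in [0, 1).
    return (randint4x(seed, offset)[0] >> 8) * 2.0 ** -24


def randn(seed, offset):
    # Box-Muller: two uniforms -> one standard normal sample.
    r = randint4x(seed, offset)
    u1 = ((r[0] >> 8) + 1) * 2.0 ** -24  # in (0, 1], so log(u1) is finite
    u2 = (r[1] >> 8) * 2.0 ** -24
    return math.sqrt(-2.0 * math.log(u1)) * math.cos(2.0 * math.pi * u2)
```

Calling rand(42, 7) twice returns the same value, while changing either the seed or the offset gives an independent-looking sample.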

Iterators

range

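Iterator that counts upward over a range of integers, like Python's built-in range. Inside a jitted function it compiles to a loop and can carry extra compiler hints such as num_stages.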

static_range

Iterator that counts upward like range, but the loop is fully unrolled at compile time, so its bounds must be compile-time constants.

Inline Assembly

inline_asm_elementwise

Execute inline assembly over a tensor.

Compiler Hint Ops

debug_barrier

Insert a barrier to synchronize all threads in a block.

max_constancy

Lets the compiler know that input is constant within each group of value consecutive elements; for example, with value 4, [0, 0, 0, 0, 1, 1, 1, 1] qualifies.

max_contiguous

Lets the compiler know that input is contiguous within each group of value consecutive elements, i.e. each group counts up by one (for example, [0, 1, 2, 3] for value 4).

multiple_of

Lets the compiler know that the values in input are all multiples of value.
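These hints are promises the compiler relies on rather than checks it performs, so a hint that does not hold may lead to incorrect results. To make the promised properties concrete, the hypothetical checkers below test a Python list against each hint. They assume the constancy and contiguity hints apply to each aligned group of value consecutive elements, which is a reading of the descriptions above rather than a documented definition.

```python
def satisfies_multiple_of(values, k):
    # multiple_of(x, k): every element is divisible by k.
    return all(v % k == 0 for v in values)


def satisfies_max_contiguous(values, k):
    # Each aligned group of k elements counts up by one, e.g. k=4: [0, 1, 2, 3, 8, 9, 10, 11].
    return all(values[g + i] == values[g] + i
               for g in range(0, len(values), k)
               for i in range(min(k, len(values) - g)))


def satisfies_max_constancy(values, k):
    # Each aligned group of k elements holds a single value, e.g. k=4: [0, 0, 0, 0, 1, 1, 1, 1].
    return all(values[g + i] == values[g]
               for g in range(0, len(values), k)
               for i in range(min(k, len(values) - g)))
```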

Debug Ops

static_print

Print the values at compile time.

static_assert

Assert the condition at compile time.

device_print

Print the values at runtime from the device.

device_assert

Assert the condition at runtime from the device.