triton.language.trans

triton.language.trans(input: tensor, *dims)

Permutes the dimensions of a tensor.

If no permutation is specified, the function tries a (1, 0) permutation, i.e. it tries to transpose a 2D tensor.

Parameters:
  • input – The input tensor.

  • dims – The desired ordering of dimensions. For example, (2, 1, 0) reverses the order of the dimensions of a 3D tensor.
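The effect of dims on a tensor's shape can be illustrated with NumPy. The sketch below is a NumPy model of the documented behavior, not Triton code. It assumes that dims follows the same convention as np.transpose, where output dimension i is input dimension dims[i] (for (2, 1, 0) both possible readings agree). trans_shape is a hypothetical helper that only computes the resulting shape.

```python
import numpy as np

def trans_shape(shape, dims=(1, 0)):
    """Hypothetical helper: the shape produced by permuting `shape` with `dims`.

    Assumes np.transpose's convention: output dimension i is input dimension
    dims[i]. The default (1, 0) models the 2D transpose used when no
    permutation is specified.
    """
    if sorted(dims) != list(range(len(shape))):
        raise ValueError(f"dims {dims} is not a permutation of range({len(shape)})")
    return tuple(shape[d] for d in dims)

x = np.arange(24).reshape(2, 3, 4)

# (2, 1, 0) reverses the dimension order of a 3D tensor.
y = np.transpose(x, (2, 1, 0))
assert y.shape == trans_shape(x.shape, (2, 1, 0)) == (4, 3, 2)
# Element [i, j, k] of y comes from element [k, j, i] of x.
assert y[3, 1, 0] == x[0, 1, 3]

# With the default (1, 0), a 2D tile is simply transposed.
m = np.arange(6).reshape(2, 3)
assert trans_shape(m.shape) == (3, 2)
assert (np.transpose(m, (1, 0)) == m.T).all()
```

Inside a Triton kernel, only the permutation itself is written (for example tl.trans(x, (2, 1, 0))); the checks above are just a way to reason about the resulting shapes on the host.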

dims can be passed as a tuple or as individual parameters:

# These are equivalent
trans(x, (2, 1, 0))
trans(x, 2, 1, 0)
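Both call styles reduce to the same permutation. The following is a minimal sketch of how such arguments could be normalized; normalize_dims is a hypothetical name, and this is not Triton's actual implementation.

```python
def normalize_dims(*dims):
    """Hypothetical sketch: collapse trans(x, (2, 1, 0)) and trans(x, 2, 1, 0)
    into one tuple. Not Triton's actual implementation."""
    # A single tuple/list argument is unpacked into its elements.
    if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
        dims = dims[0]
    return tuple(dims)

# Both call styles yield the same permutation.
assert normalize_dims((2, 1, 0)) == normalize_dims(2, 1, 0) == (2, 1, 0)
# No arguments give an empty tuple, which the documented default replaces with (1, 0).
assert normalize_dims() == ()
```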

permute() is equivalent to this function, except that it lacks the special case for when no permutation is specified.
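The difference can be modeled in NumPy. In the sketch below, permute_ref and trans_ref are hypothetical stand-ins, not Triton functions. permute_ref requires one index per dimension, and trans_ref adds only the fallback to (1, 0). The exact error that Triton's permute() reports when dims are missing is not modeled here.

```python
import numpy as np

def permute_ref(x, *dims):
    """Hypothetical NumPy model of permute(): dims are required."""
    if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
        dims = tuple(dims[0])
    if len(dims) != x.ndim:
        raise ValueError(f"expected {x.ndim} dims, got {len(dims)}")
    return np.transpose(x, dims)

def trans_ref(x, *dims):
    """Hypothetical NumPy model of trans(): permute_ref plus the special case
    that an empty permutation means (1, 0)."""
    if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
        dims = tuple(dims[0])
    if not dims:
        dims = (1, 0)  # only valid for 2D inputs, hence "tries to"
    return permute_ref(x, dims)

m = np.arange(6).reshape(2, 3)
assert trans_ref(m).shape == (3, 2)
assert (trans_ref(m) == permute_ref(m, 1, 0)).all()

# Without dims, the permute model rejects the call instead of defaulting.
try:
    permute_ref(m)
except ValueError:
    pass
else:
    raise AssertionError("permute_ref should require dims")
```

In this model, calling trans_ref with no dims on a 3D array also raises, because the (1, 0) default covers only two dimensions.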

This function can also be called as a member function on tensor, as x.trans(...) instead of trans(x, ...).