triton.language.trans

triton.language.trans(input: tensor, *dims)

Permutes the dimensions of a tensor.

If dims is not specified, the function defaults to the (1, 0) permutation, which transposes a 2D tensor.

Parameters:
  • input – The input tensor.

  • dims – The desired ordering of dimensions. For example, (2, 1, 0) reverses the order of the dimensions in a 3D tensor.
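The ordering in dims determines the output shape: output dimension i takes its size (and its data) from input dimension dims[i]. This is the same convention as NumPy's transpose. The hypothetical helper permuted_shape below is not part of Triton. It is a minimal plain-Python sketch of that shape rule, and it assumes dims must name every input dimension exactly once:

```python
def permuted_shape(shape, dims):
    """Hypothetical helper (not a Triton API): shape after permuting `shape` by `dims`.

    Output dimension i takes its size from input dimension dims[i].
    """
    if sorted(dims) != list(range(len(shape))):
        raise ValueError(f"{dims} is not a permutation of 0..{len(shape) - 1}")
    return tuple(shape[d] for d in dims)


# Reversing the dimensions of a (2, 3, 4) tensor:
print(permuted_shape((2, 3, 4), (2, 1, 0)))  # (4, 3, 2)
# Moving the last dimension to the front:
print(permuted_shape((2, 3, 4), (2, 0, 1)))  # (4, 2, 3)
```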

dims can be passed as a tuple or as individual parameters:

# These are equivalent
trans(x, (2, 1, 0))
trans(x, 2, 1, 0)

permute() is equivalent to this function, except that it has no default: it does not apply the (1, 0) special case when no permutation is specified.
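Triton kernels need a GPU (or Triton's interpreter) to run. The sketch below therefore models the calling conventions described above on host-side NumPy arrays. It accepts dims either as one tuple or as separate arguments, and only the trans-like function falls back to (1, 0). The names trans_model and permute_model are hypothetical stand-ins, not Triton's implementation. The check that dims is a full permutation is an assumption of this model:

```python
import numpy as np


def _unwrap_dims(dims):
    # Accept both f(x, (2, 1, 0)) and f(x, 2, 1, 0).
    if len(dims) == 1 and isinstance(dims[0], (tuple, list)):
        return tuple(dims[0])
    return tuple(dims)


def permute_model(x, *dims):
    """Hypothetical NumPy stand-in for permute(): dims must be given."""
    dims = _unwrap_dims(dims)
    if sorted(dims) != list(range(x.ndim)):
        raise ValueError(f"dims {dims} is not a permutation of range({x.ndim})")
    # Output dimension i is input dimension dims[i].
    return np.transpose(x, dims)


def trans_model(x, *dims):
    """Hypothetical NumPy stand-in for trans(): like permute_model, but no dims means (1, 0)."""
    dims = _unwrap_dims(dims)
    if not dims:
        dims = (1, 0)  # the 2D-transpose special case
    return permute_model(x, *dims)


x = np.arange(6).reshape(2, 3)
print(trans_model(x).shape)  # (3, 2)

y = np.arange(24).reshape(2, 3, 4)
print(trans_model(y, (2, 1, 0)).shape)  # (4, 3, 2)
print(trans_model(y, 2, 1, 0).shape)    # (4, 3, 2)
```

In this model, calling permute_model(x) with no dims raises a ValueError, while trans_model(x) transposes a 2D array. Calling trans_model on a 3D array without dims also raises, because (1, 0) does not cover all three dimensions.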

This function can also be called as a member function on tensor, as x.trans(...) instead of trans(x, ...).