triton.language.expand_dims

triton.language.expand_dims(input, axis)

Expand the shape of a tensor by inserting new length-1 dimensions.

Axis indices refer to positions in the resulting tensor, not the input, so result.shape[a] will be 1 for every index a given in axis.

Parameters:
  • input (tl.tensor) – The input tensor.

  • axis (int | Sequence[int]) – The index or indices at which to insert the new axes.

This function can also be called as a member function on tensor, as x.expand_dims(...) instead of expand_dims(x, ...).
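
To make the axis convention concrete, the following pure-Python sketch computes the shape that results from inserting length-1 dimensions at the given indices. It is an illustration only, not Triton's implementation: the helper name expanded_shape is hypothetical, and the sketch assumes the indices are non-negative and distinct.

```python
from typing import Sequence, Tuple, Union


def expanded_shape(shape: Sequence[int],
                   axis: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Hypothetical helper: shape after inserting length-1 dims at `axis`.

    Assumption: indices are non-negative, distinct, and refer to
    positions in the *resulting* shape.
    """
    axes = [axis] if isinstance(axis, int) else list(axis)
    out_rank = len(shape) + len(axes)

    if len(set(axes)) != len(axes):
        raise ValueError("axis indices must be distinct")
    if any(a < 0 or a >= out_rank for a in axes):
        raise ValueError(f"axis indices must lie in [0, {out_rank})")

    # Walk the result positions: emit 1 at each requested index,
    # otherwise consume the next dimension of the input shape.
    remaining = iter(shape)
    return tuple(1 if i in axes else next(remaining) for i in range(out_rank))


print(expanded_shape((4, 8), 1))       # (4, 1, 8)
print(expanded_shape((4, 8), (0, 3)))  # (1, 4, 8, 1)
```

In the second call, the indices 0 and 3 are both positions in the four-dimensional result, which is why the input dimensions 4 and 8 land at positions 1 and 2.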