triton.language.topk

triton.language.topk(x, k: constexpr, dim: constexpr | None = None)

Returns the k largest elements of the input tensor along the specified dimension.

The elements are returned in sorted order (largest first).
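The pure-Python model below is a rough sketch of the documented semantics, useful for checking expected results on the host. It is not Triton code, and the name `topk_reference` is hypothetical. For 2D input it assumes the top k is taken independently in each row (the last dimension), and ties are kept as-is.

```python
from typing import Sequence, Union

Row = Sequence[float]


def topk_reference(x: Union[Row, Sequence[Row]], k: int) -> list:
    """Hypothetical host-side model of the documented tl.topk semantics.

    Accepts a 1D sequence or a 2D sequence of rows. For 2D input the
    top k is computed independently for each row (the last dimension).
    """
    if len(x) > 0 and isinstance(x[0], (list, tuple)):
        # 2D input: reduce along the last dimension, one row at a time.
        return [topk_reference(row, k) for row in x]
    # 1D input: sort in descending order (largest first) and keep k elements.
    return sorted(x, reverse=True)[:k]


print(topk_reference(list(range(16)), 4))
print(topk_reference([[3, 1, 4, 1], [5, 9, 2, 6]], 2))
```

The second call returns `[[4, 3], [9, 6]]`: each row is reduced separately, and within each row the results are in descending order.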

Parameters:
  • x (Tensor) – The input tensor.

  • k (int) – The number of top elements to return. Must be a power of two.

  • dim (int, optional) – The dimension along which to find the top k elements. If None, uses the last dimension. Currently only the last dimension is supported.
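The documented constraints on k and dim can be summarized in a small host-side check that also computes the result shape. This is a sketch only. The helper `topk_output_shape` is hypothetical, and two of its rules are assumptions that do not come from this page: negative dim values are normalized Python-style, and k may not exceed the size of the reduced dimension.

```python
from typing import Optional, Tuple


def topk_output_shape(shape: Tuple[int, ...], k: int,
                      dim: Optional[int] = None) -> Tuple[int, ...]:
    """Hypothetical helper: validate topk-style arguments, return result shape."""
    ndim = len(shape)
    if ndim == 0:
        raise ValueError("input must have at least one dimension")
    last = ndim - 1
    if dim is None:
        dim = last  # documented default: the last dimension
    elif dim < 0:
        dim += ndim  # assumption: Python-style negative indexing
    if dim != last:
        # Documented: only the last dimension is currently supported.
        raise ValueError("only the last dimension is supported")
    # A positive integer is a power of two iff exactly one bit is set.
    if k <= 0 or k & (k - 1) != 0:
        raise ValueError(f"k must be a power of two, got {k}")
    if k > shape[last]:
        # Assumption: k cannot exceed the size of the reduced dimension.
        raise ValueError(f"k={k} exceeds size {shape[last]} of dim {last}")
    # The reduced dimension shrinks to k; all other dimensions are unchanged.
    return shape[:last] + (k,)


print(topk_output_shape((8, 16), 4))
```

For an input of shape `(8, 16)` and `k = 4`, the result has shape `(8, 4)`. A value such as `k = 3`, or a request for `dim=0` on a 2D input, is rejected.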

Returns:

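A tensor containing the k largest elements of x along the specified dimension, sorted in descending order. Its shape matches that of x, except that the size of that dimension is k.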

Return type:

Tensor

Example:

import triton.language as tl

# Inside a @triton.jit kernel: get the top 4 elements of a 1D tensor
x = tl.arange(0, 16)
top4 = tl.topk(x, 4)  # [15, 14, 13, 12]