triton.language.argmax

triton.language.argmax(input, axis, tie_break_left=True)

Returns the index of the maximum value among the elements of the input tensor along the provided axis.
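A minimal pure-Python sketch of these semantics. This is not Triton code and does not run on a GPU. `argmax_reference` is a hypothetical helper that treats a 2-D input as a nested list and returns, for each slice along `axis`, the position of its largest value. It relies on Python's built-in `max` returning the first maximal item, so ties resolve to the left-most index, as with `tie_break_left=True`. NaN inputs are not modeled.

```python
from typing import List


def argmax_reference(x: List[List[float]], axis: int) -> List[int]:
    """Hypothetical model of argmax on a 2-D input (not the Triton API).

    Ties resolve to the smallest index because built-in max() returns
    the first maximal item it encounters. NaN handling is not modeled.
    """
    rows, cols = len(x), len(x[0])
    if axis == 1:
        # Reduce across columns: one index per row.
        return [max(range(cols), key=lambda j: x[i][j]) for i in range(rows)]
    if axis == 0:
        # Reduce across rows: one index per column.
        return [max(range(rows), key=lambda i: x[i][j]) for j in range(cols)]
    raise ValueError("axis must be 0 or 1 for a 2-D input")
```

For `x = [[1, 5, 3], [7, 2, 7]]`, reducing along axis 1 gives one index per row (`[1, 0]`; the second row's tie between positions 0 and 2 resolves left). Reducing along axis 0 gives one index per column (`[1, 0, 1]`).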

Parameters:
  • input – the input values

  • axis – the dimension along which the reduction should be done

  • tie_break_left – if true, return the left-most index when several elements share the maximum value (a tie); this guarantee applies only to values that aren’t NaN
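Why the tie-breaking flag matters for a parallel reduction can be sketched in plain Python. The names `argmax_combine` and `tree_argmax` are hypothetical, and the combine rule is an assumption about how such a reduction could work, not Triton's internal implementation. Each element is paired with its index, and pairs are merged in a tree. With the left tie-break, the combine step picks the larger value and, on equal values, the smaller index. That is a total order, so the result is the left-most index regardless of tree shape. Without it, which tied index survives depends on the combine order; the documentation above only specifies the left-most behavior. Because IEEE comparisons involving NaN are always false, `v1 == v2` never holds for NaN, which is consistent with the guarantee being limited to non-NaN values.

```python
from typing import List, Sequence, Tuple


def argmax_combine(v1: float, i1: int, v2: float, i2: int,
                   tie_break_left: bool) -> Tuple[float, int]:
    # Hypothetical combine step: keep the first candidate if it is strictly
    # larger, or (with tie_break_left) if it ties and has the smaller index.
    tie = tie_break_left and v1 == v2 and i1 < i2
    if v1 > v2 or tie:
        return v1, i1
    return v2, i2


def tree_argmax(values: Sequence[float], tie_break_left: bool = True) -> int:
    # Pair each value with its position, then merge neighbours level by
    # level, as a tree-shaped parallel reduction would.
    pairs: List[Tuple[float, int]] = [(v, i) for i, v in enumerate(values)]
    while len(pairs) > 1:
        merged = []
        for k in range(0, len(pairs) - 1, 2):
            (v1, i1), (v2, i2) = pairs[k], pairs[k + 1]
            merged.append(argmax_combine(v1, i1, v2, i2, tie_break_left))
        if len(pairs) % 2:
            # An odd element out passes through to the next level unchanged.
            merged.append(pairs[-1])
        pairs = merged
    return pairs[0][1]
```

In this model, `tree_argmax([3, 7, 7, 1])` returns `1` with the left tie-break and `2` without it. The second result is an artifact of this particular combine order, not a documented Triton behavior.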