triton.language.reduce

triton.language.reduce(input, axis, combine_fn, keep_dims=False)

Applies combine_fn to all elements of the input tensor (or tuple of tensors) along the given axis, collapsing that dimension into a single result.

Parameters:
  • input (Tensor) – the input tensor, or a tuple of tensors

  • axis (int | None) – the dimension along which the reduction should be done. If None, all dimensions are reduced.

  • combine_fn (Callable) – a function to combine two groups of scalar tensors (must be marked with @triton.jit)

  • keep_dims (bool) – if True, keep each reduced dimension in the result with length 1

This function can also be called as a member function on tensor, as x.reduce(...) instead of reduce(x, ...).