triton.language.reduce

triton.language.reduce(input, axis, combine_fn)

Applies combine_fn to all elements of the input tensor(s) along the provided axis, reducing them to a single result along that axis.
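To make the semantics concrete, here is a minimal CPU-only sketch in plain Python. It is not Triton code and does not reproduce Triton's actual lowering. It assumes, as is usual for parallel reductions, that elements are combined pairwise in a tree rather than in a strict left-to-right fold. For that reason combine_fn should be associative (and, to be safe, commutative). The helper names tree_reduce and reduce_2d are hypothetical and exist only for this illustration.

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")


def tree_reduce(values: List[T], combine_fn: Callable[[T, T], T]) -> T:
    """Combine values pairwise, level by level (a tree, not a left fold)."""
    if not values:
        raise ValueError("cannot reduce an empty sequence")
    level = list(values)
    while len(level) > 1:
        nxt = [combine_fn(level[i], level[i + 1])
               for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:  # an odd trailing element carries over unchanged
            nxt.append(level[-1])
        level = nxt
    return level[0]


def reduce_2d(x: List[List[T]], axis: int,
              combine_fn: Callable[[T, T], T]) -> List[T]:
    """Reduce a rectangular 2-D list along axis 0 (columns) or 1 (rows)."""
    if axis == 1:
        return [tree_reduce(row, combine_fn) for row in x]
    if axis == 0:
        return [tree_reduce(list(col), combine_fn) for col in zip(*x)]
    raise ValueError("axis must be 0 or 1 for a 2-D input")


def add(a, b):
    return a + b


x = [[1, 2, 3],
     [4, 5, 6]]
print(reduce_2d(x, axis=1, combine_fn=add))  # [6, 15]
print(reduce_2d(x, axis=0, combine_fn=add))  # [5, 7, 9]
print(reduce_2d(x, axis=1, combine_fn=max))  # [3, 6]
```

A non-associative combine_fn, such as subtraction, can give different answers under a tree order than under a left fold. This model makes that dependence visible.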

Parameters:
  • input – the input tensor, or tuple of tensors

  • axis – the dimension along which the reduction should be done

  • combine_fn – a function that combines two groups of scalar tensors (one scalar per input tensor in each group) into a single group; it must be marked with @triton.jit
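When input is a tuple of tensors, each "group" holds one scalar from every tensor at the same position. The sketch below models this in plain Python, again CPU-only and not Triton code. It assumes that combine_fn receives the left group's scalars followed by the right group's scalars, as in combine_fn(a0, a1, ..., b0, b1, ...), and returns the combined group as a tuple. That calling convention is our reading of how argmax-style combine functions are written for Triton. In a real kernel, argmax_combine would be decorated with @triton.jit and passed to tl.reduce. The names reduce_tuple and argmax_combine are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple


def reduce_tuple(inputs: Sequence[List], combine_fn: Callable[..., Tuple]) -> Tuple:
    """Reduce several equal-length 1-D lists together, pairwise in a tree.

    combine_fn receives the left group's scalars followed by the right
    group's scalars and returns one combined group as a tuple.
    """
    n = len(inputs[0])
    if n == 0 or any(len(seq) != n for seq in inputs):
        raise ValueError("inputs must be non-empty and of equal length")
    # Position i of every input list forms one group.
    level = [tuple(seq[i] for seq in inputs) for i in range(n)]
    while len(level) > 1:
        nxt = [tuple(combine_fn(*level[i], *level[i + 1]))
               for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:  # an odd trailing group carries over unchanged
            nxt.append(level[-1])
        level = nxt
    return level[0]


def argmax_combine(value1, index1, value2, index2):
    # Keep the larger value. On ties keep the smaller index, so the result
    # does not depend on the order in which groups are combined.
    if value1 > value2 or (value1 == value2 and index1 < index2):
        return value1, index1
    return value2, index2


values = [3.0, 7.5, 1.0, 7.5, 2.0]
indices = list(range(len(values)))
print(reduce_tuple((values, indices), argmax_combine))  # (7.5, 1)
```

The explicit tie-break is a design choice. Without it, the index reported for repeated maxima could depend on the reduction order.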