triton.language.reduce

triton.language.reduce(input, axis, combine_fn, keep_dims=False)

Applies combine_fn to all elements of the input tensor(s) along the provided axis.

Parameters:
  • input – the input tensor, or tuple of tensors

  • axis – the dimension along which the reduction should be done. If None, reduce all dimensions.

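  • combine_fn – a function that combines two groups of scalar tensors into one group (must be marked with @triton.jit). When input is a tuple of k tensors, combine_fn receives 2k scalars (the first group followed by the second) and returns k values.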

  • keep_dims – if True, keep the reduced dimensions as dimensions of length 1.