triton.language.associative_scan

triton.language.associative_scan(input, axis, combine_fn, reverse=False)

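Applies combine_fn to each element of the input tensor(s) along the provided axis, combining it with a running carry and then updating the carry. The result is an inclusive scan: for example, when combine_fn adds its two arguments, the output is the cumulative sum along axis.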
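The carry semantics can be modeled sequentially in plain Python. The sketch below is a hypothetical 1-D reference model (`reference_associative_scan` is not part of Triton). It assumes the carry is passed to combine_fn before the current element, which only matters for non-commutative functions. The compiled kernel may group the combines in a different order, which is why combine_fn must be associative.

```python
from typing import Callable, List, Sequence, Tuple, Union


def reference_associative_scan(
    inputs: Union[List, Tuple[Sequence, ...]],
    combine_fn: Callable,
    reverse: bool = False,
):
    """Hypothetical sequential model of an inclusive associative scan.

    `inputs` is a single list or a tuple of equal-length sequences (mirroring a
    tensor or a tuple of tensors). For k sequences, `combine_fn` receives the k
    carry values followed by the k current values and returns k values.
    """
    single = not isinstance(inputs, tuple)
    if single:
        inputs = (inputs,)
    n = len(inputs[0])
    outputs = tuple([None] * n for _ in inputs)
    order = range(n - 1, -1, -1) if reverse else range(n)
    carry = None
    for i in order:
        element = tuple(seq[i] for seq in inputs)
        if carry is None:
            carry = element  # the first element visited seeds the carry
        else:
            result = combine_fn(*carry, *element)
            carry = result if isinstance(result, tuple) else (result,)
        for out, value in zip(outputs, carry):
            out[i] = value  # inclusive: position i includes its own element
    return outputs[0] if single else outputs


xs = [3, 1, 4, 1, 5]
print(reference_associative_scan(xs, lambda a, b: a + b))                # [3, 4, 8, 9, 14]
print(reference_associative_scan(xs, lambda a, b: a + b, reverse=True))  # [14, 11, 10, 6, 5]
```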

Parameters:
  • input (Tensor) – the input tensor, or tuple of tensors

  • axis (int) – the dimension along which the scan should be done

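  • combine_fn (Callable) – a function that combines two groups of scalar tensors, each group holding one scalar per input tensor, and returns one scalar per input tensor (must be associative and marked with @triton.jit)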

  • reverse (bool) – whether to apply the associative scan in the reverse direction along axis

This function can also be called as a member function on tensor, as x.associative_scan(...) instead of associative_scan(x, ...).