triton.language.associative_scan
triton.language.associative_scan(input, axis, combine_fn, reverse=False)
Applies combine_fn to each element of the input tensor(s) along the provided axis, together with a carry, and updates the carry at each step.
Parameters:
input (Tensor) – the input tensor, or tuple of tensors
axis (int) – the dimension along which the scan should be done
combine_fn (Callable) – a function to combine two groups of scalar tensors (must be marked with @triton.jit)
reverse (bool) – whether to apply the associative scan in the reverse direction along axis
This function can also be called as a member function on tensor, as x.associative_scan(...) instead of associative_scan(x, ...).
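The description above (apply combine_fn to each element with a carry, then update the carry) can be modelled in plain Python. This is only a sketch of the semantics, not Triton's implementation: reference_scan is a hypothetical helper, it handles a single 1-D sequence rather than a tensor axis, and it assumes the scan is inclusive (the first output equals the first input). Inside a real kernel, combine_fn would have to be decorated with @triton.jit; here it is an ordinary Python function. The examples use commutative functions only, because the argument order passed to combine_fn when reverse=True is an assumption of this sketch.

```python
def reference_scan(values, combine_fn, reverse=False):
    """Hypothetical pure-Python model of an inclusive scan over a 1-D sequence."""
    # For a reverse scan, walk the sequence from the last element to the first.
    items = list(reversed(values)) if reverse else list(values)
    out = []
    carry = None
    for i, item in enumerate(items):
        # The first element seeds the carry; later elements are combined into it.
        carry = item if i == 0 else combine_fn(carry, item)
        out.append(carry)
    # Restore the original element order for a reverse scan.
    return list(reversed(out)) if reverse else out


def add(a, b):
    # Addition is associative, so it is a valid combine function.
    return a + b


forward = reference_scan([1, 2, 3, 4], add)                 # running sum from the left
backward = reference_scan([1, 2, 3, 4], add, reverse=True)  # running sum from the right
running_max = reference_scan([3, 1, 4, 1, 5], max)          # running maximum
```

The combine function must be associative: a parallel scan is free to group the combinations differently from this sequential loop, and only associativity guarantees the same result.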