triton.language.reduce
- triton.language.reduce(input, axis, combine_fn, keep_dims=False)
Applies the combine_fn to all elements in the input tensors along the provided axis.
- Parameters:
input (Tensor) – the input tensor, or tuple of tensors
axis (int | None) – the dimension along which the reduction should be done. If None, reduce all dimensions
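combine_fn (Callable) – a function to combine two groups of scalar tensors (must be marked with @triton.jit). For a tuple of N input tensors, it receives the N scalars of one group followed by the N scalars of the other, and returns N combined scalars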
keep_dims (bool) – if true, keep the reduced dimensions with length 1
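To make the parameters concrete, here is a minimal reference model of what the reduction computes, written with NumPy. It is not Triton code and does not run on a GPU: the name `reduce_reference` and the `argmax_combine` helper are hypothetical, `combine_fn` is a plain Python function rather than an @triton.jit one, and the model assumes `combine_fn` is associative, because this page does not specify the order in which Triton combines elements (the model simply folds left to right).

```python
import functools
import itertools

import numpy as np


def reduce_reference(inputs, axis, combine_fn, keep_dims=False):
    """Hypothetical NumPy model of the documented reduce semantics (not Triton)."""
    is_tuple = isinstance(inputs, tuple)
    arrays = tuple(np.asarray(a) for a in (inputs if is_tuple else (inputs,)))
    ndim = arrays[0].ndim
    if axis is None:
        # axis=None reduces all dimensions: flatten and reduce the single axis.
        arrays = tuple(a.reshape(-1) for a in arrays)
        axis_to_reduce = 0
    else:
        axis_to_reduce = axis
    # Move the reduced axis last so each trailing "row" is one reduction.
    moved = tuple(np.moveaxis(a, axis_to_reduce, -1) for a in arrays)
    out_shape = moved[0].shape[:-1]
    length = moved[0].shape[-1]
    outs = tuple(np.empty(out_shape, dtype=a.dtype) for a in arrays)

    def step(acc, nxt):
        # combine_fn takes two groups of scalars: all of acc, then all of nxt.
        res = combine_fn(*acc, *nxt)
        return res if isinstance(res, tuple) else (res,)

    for idx in itertools.product(*(range(n) for n in out_shape)):
        # One group per position along the reduced axis, one scalar per input.
        groups = [tuple(m[idx + (k,)] for m in moved) for k in range(length)]
        for out, value in zip(outs, functools.reduce(step, groups)):
            out[idx] = value

    if keep_dims:
        # Re-insert the reduced dimension(s) with length 1.
        if axis is None:
            outs = tuple(o.reshape((1,) * ndim) for o in outs)
        else:
            outs = tuple(np.expand_dims(o, axis) for o in outs)
    return outs if is_tuple else outs[0]


def argmax_combine(v1, i1, v2, i2):
    # Two groups of two scalars (value, index) in, one (value, index) group out.
    if v1 > v2 or (v1 == v2 and i1 < i2):
        return v1, i1
    return v2, i2


x = np.array([[1, 5, 2], [7, 0, 3]])
print(reduce_reference(x, 1, lambda a, b: a if a > b else b))  # [5 7]

vals, idxs = np.array([3, 9, 9, 1]), np.arange(4)
best = reduce_reference((vals, idxs), 0, argmax_combine)
print(int(best[0]), int(best[1]))  # 9 1
```

The tuple case shows why `combine_fn` works on groups: passing `(values, indices)` lets one reduction carry both the running maximum and its position, which a single-tensor reduction cannot express.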
This function can also be called as a member function on tensor, as x.reduce(...) instead of reduce(x, ...).