triton.language.cumsum
- triton.language.cumsum(input, axis=0, reverse=False, dtype: constexpr | None = None)
Returns the cumulative sum of the elements in the input tensor along the provided axis.
- Parameters:
input (Tensor) – the input values
axis (int) – the dimension along which the scan is performed
reverse (bool) – if True, the scan is performed in the reverse direction, from the last element toward the first
dtype (tl.dtype) – the desired data type of the returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. If not specified, small integer types (< 32 bits) are upcast to prevent overflow. Note that tl.bfloat16 inputs are automatically promoted to tl.float32.
This function can also be called as a member function on tensor, as x.cumsum(...) instead of cumsum(x, ...).
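To make the axis, reverse and dtype rules concrete, the following is a minimal NumPy sketch that models the documented behavior on the CPU. It is not Triton code and does not call the Triton API: cumsum_reference is a hypothetical name, widening small integers to a 32-bit integer of the same signedness is an assumption (the text above only says they are upcast), and the bfloat16 promotion is omitted because NumPy has no bfloat16 type.

```python
import numpy as np


def cumsum_reference(x, axis=0, reverse=False, dtype=None):
    """Hypothetical NumPy model of the documented tl.cumsum semantics.

    This is NOT Triton code; it only mirrors the behavior described above
    so the axis, reverse and dtype rules can be checked on the CPU.
    """
    x = np.asarray(x)
    if dtype is not None:
        # Explicit dtype: cast the input first, then scan in that type.
        x = x.astype(dtype)
    elif np.issubdtype(x.dtype, np.integer) and x.dtype.itemsize < 4:
        # Assumption: small integers are widened to a 32-bit integer of the
        # same signedness. bfloat16 -> float32 is not modelled (no NumPy bfloat16).
        wide = np.int32 if np.issubdtype(x.dtype, np.signedinteger) else np.uint32
        x = x.astype(wide)
    if reverse:
        # Reverse scan: out[i] is the sum of x[i], x[i+1], ..., x[-1] along axis.
        flipped = np.flip(x, axis=axis)
        return np.flip(np.cumsum(flipped, axis=axis, dtype=x.dtype), axis=axis)
    # Forward scan: out[i] is the sum of x[0], ..., x[i] along axis.
    # dtype=x.dtype stops NumPy from widening to its default platform integer.
    return np.cumsum(x, axis=axis, dtype=x.dtype)


if __name__ == "__main__":
    x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int8)
    print(cumsum_reference(x, axis=1).tolist())                # [[1, 3, 6], [4, 9, 15]]
    print(cumsum_reference(x, axis=1, reverse=True).tolist())  # [[6, 5, 3], [15, 11, 6]]
    print(cumsum_reference(x).dtype)                           # int32
```

The default axis=0 mirrors the Triton signature; NumPy's own np.cumsum defaults to axis=None, which flattens the input first. Passing dtype=np.int8 explicitly to the model reproduces the overflow that the automatic upcast is meant to avoid.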