triton.language.cumsum

triton.language.cumsum(input, axis=0, reverse=False, dtype: constexpr | None = None)

Returns the cumulative sum of the elements of the input tensor along the provided axis.

Parameters:
  • input (Tensor) – the input values

  • axis (int) – the dimension along which the scan should be done

  • reverse (bool) – if true, the scan is performed in the reverse direction

  • dtype (tl.dtype) – the desired data type of the returned tensor. If specified, the input tensor is cast to dtype before the operation is performed. If not specified, small integer types (< 32 bits) are upcast to prevent overflow. Note that tl.bfloat16 inputs are automatically promoted to tl.float32.
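
The scan semantics described by these parameters can be modeled in plain Python. This is only a reference sketch for the 1-D case, not Triton's implementation: `cumsum_reference` and `wrap_int8` are hypothetical helper names, and the reverse behavior is assumed to be "flip, scan, flip back", so that element i of the result sums the inputs from position i to the end. The `wrap_int8` helper illustrates why accumulating in a narrow integer type could overflow, which is the motivation given above for upcasting small integer types.

```python
def cumsum_reference(values, reverse=False):
    """Pure-Python model of an inclusive 1-D cumulative sum."""
    # For a reverse scan, walk the input from the last element to the first.
    items = list(reversed(values)) if reverse else list(values)
    out = []
    running = 0
    for v in items:
        running += v
        out.append(running)
    # Flip back so that out[i] is the sum of values[i:] in the reverse case.
    return list(reversed(out)) if reverse else out


def wrap_int8(n):
    """Two's-complement wraparound of a Python int to 8 bits (illustration only)."""
    return (n + 128) % 256 - 128


forward = cumsum_reference([1, 2, 3, 4])                 # [1, 3, 6, 10]
backward = cumsum_reference([1, 2, 3, 4], reverse=True)  # [10, 9, 7, 4]

# 100 + 100 = 200 does not fit in a signed 8-bit integer and would wrap to -56.
overflowed = wrap_int8(cumsum_reference([100, 100])[-1])
```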

This function can also be called as a member function on a tensor, as x.cumsum(...) instead of cumsum(x, ...).
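
A minimal kernel sketch showing both call forms follows. It assumes a CUDA-capable GPU, PyTorch for host-side tensors, and a 1-D input whose length is a power of two (a requirement of `tl.arange`). The names `cumsum_kernel` and `cumsum_both` are hypothetical and used only for illustration.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def cumsum_kernel(x_ptr, fwd_ptr, rev_ptr, BLOCK: tl.constexpr):
    # A single program loads one block of BLOCK contiguous elements.
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    # Free-function form: forward scan along axis 0.
    fwd = tl.cumsum(x, axis=0)
    # Member-function form: the same operation, scanning in reverse.
    rev = x.cumsum(axis=0, reverse=True)
    tl.store(fwd_ptr + offs, fwd)
    tl.store(rev_ptr + offs, rev)


def cumsum_both(x):
    # Hypothetical host-side helper: x is assumed to be a 1-D CUDA tensor
    # with a power-of-two number of elements.
    fwd = torch.empty_like(x)
    rev = torch.empty_like(x)
    cumsum_kernel[(1,)](x, fwd, rev, BLOCK=x.numel())
    return fwd, rev
```

Under these assumptions, calling `cumsum_both(torch.arange(1, 5, dtype=torch.int32, device="cuda"))` returns the forward and reverse cumulative sums as two tensors of the same shape and dtype as the input.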