triton.language.cumsum

triton.language.cumsum(input, axis=0)

Returns the cumulative sum of the elements of the input tensor along the provided axis.

Parameters:
  • input – the input tensor whose values are summed

  • axis – the dimension along which the scan is performed
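The sketch below models the values the scan produces for a 2D block, to make the role of axis concrete. It is plain Python that runs on the CPU, not Triton code. The helper cumsum_reference is hypothetical and introduced only for illustration. It assumes an inclusive scan, where each output element includes the input at its own position, which is the conventional meaning of a cumulative sum. Inside a @triton.jit kernel, the corresponding call on a 2D block x would be tl.cumsum(x, axis=1).

```python
from itertools import accumulate
from typing import List

Block = List[List[float]]


def cumsum_reference(block: Block, axis: int = 0) -> Block:
    """Hypothetical CPU model of an inclusive cumulative sum over a 2D block.

    axis=0 scans down each column; axis=1 scans along each row.
    This is an illustration of the semantics, not Triton's implementation.
    """
    if axis == 1:
        # Scan each row independently, left to right.
        return [list(accumulate(row)) for row in block]
    if axis == 0:
        # Scan each column independently, top to bottom:
        # transpose, scan the former columns, then transpose back.
        columns = [list(accumulate(col)) for col in zip(*block)]
        return [list(row) for row in zip(*columns)]
    raise ValueError("this 2D model only supports axis 0 or 1")


x = [[1, 2, 3],
     [4, 5, 6]]
print(cumsum_reference(x, axis=0))  # [[1, 2, 3], [5, 7, 9]]
print(cumsum_reference(x, axis=1))  # [[1, 3, 6], [4, 9, 15]]
```

Changing axis changes only which elements are chained together. With axis=0 each column is summed independently, and with axis=1 each row is.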