triton.language.cast

triton.language.cast(input, dtype: dtype, fp_downcast_rounding: str | None = None, bitcast: bool = False)

Casts a tensor to the given dtype.

Parameters:
  • input (tensor) – The tensor to cast.

  • dtype (tl.dtype) – The target data type.

  • fp_downcast_rounding (str, optional) – The rounding mode for downcasting floating-point values. This parameter is only used when input is a floating-point tensor and dtype is a floating-point type with a smaller bitwidth. Supported values are "rtne" (round to nearest, ties to even) and "rtz" (round towards zero).

  • bitcast (bool, optional) – If True, the tensor's bits are reinterpreted as the given dtype instead of being converted numerically.
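The two rounding modes and the difference between a numeric cast and a bitcast can be illustrated outside Triton. The following is a minimal pure-Python sketch, not Triton's implementation: it assumes a float32 source and a bfloat16 target (chosen only because bfloat16 is the upper 16 bits of a float32), handles finite inputs only, and all helper names (f32_bits, to_bf16_rtz, and so on) are hypothetical.

```python
import struct


def f32_bits(x: float) -> int:
    """Return the 32-bit pattern of x stored as an IEEE 754 float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_f32(bits: int) -> float:
    """Reinterpret a 32-bit pattern as a float32 value."""
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def to_bf16_rtz(x: float) -> int:
    """Round towards zero: discard the low 16 bits of the float32 pattern."""
    return f32_bits(x) >> 16


def to_bf16_rtne(x: float) -> int:
    """Round to nearest, ties to even (finite inputs only in this sketch)."""
    bits = f32_bits(x)
    lsb = (bits >> 16) & 1  # lowest bit that survives the downcast
    return ((bits + 0x7FFF + lsb) >> 16) & 0xFFFF


def bf16_to_float(b: int) -> float:
    """Widen a 16-bit bfloat16 pattern back to a Python float."""
    return bits_f32(b << 16)


# 1.005859375 lies three quarters of the way from 1.0 to the next
# bfloat16 value, 1.0078125.
x = 1.005859375
print(bf16_to_float(to_bf16_rtz(x)))   # 1.0
print(bf16_to_float(to_bf16_rtne(x)))  # 1.0078125

# Exact ties go to the neighbour whose last kept bit is zero.
print(bf16_to_float(to_bf16_rtne(1.00390625)))  # 1.0
print(bf16_to_float(to_bf16_rtne(1.01171875)))  # 1.015625

# Numeric cast versus bitcast of a float32 1.0 to a 32-bit integer.
numeric = int(1.0)       # value is converted: 1
bitcast = f32_bits(1.0)  # bits are reinterpreted: 0x3F800000
print(numeric, bitcast)  # 1 1065353216
```

In this sketch "rtz" never moves a value away from zero, while "rtne" picks the closest representable value and breaks exact ties towards an even mantissa. A bitcast leaves the stored bits unchanged and changes only how they are interpreted.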

This function can also be called as a member function on tensor, as x.cast(...) instead of cast(x, ...).