triton.language.dot_scaled
- triton.language.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=triton.language.float32)
Returns the matrix product of two blocks in microscaling format. lhs and rhs use the microscaling formats described here: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

:param lhs: The first tensor to be multiplied.
:type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in the lower bits. Fp8 elements are stored as uint8 or the corresponding fp8 type.
:param lhs_scale: Scale factor for the lhs tensor.
:type lhs_scale: e8m0 type represented as a uint8 tensor.
:param lhs_format: Format of the lhs tensor. Available formats: {e2m1, e4m3, e5m2, bf16}.
:type lhs_format: str
:param rhs: The second tensor to be multiplied.
:type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in the lower bits. Fp8 elements are stored as uint8 or the corresponding fp8 type.
:param rhs_scale: Scale factor for the rhs tensor.
:type rhs_scale: e8m0 type represented as a uint8 tensor.
:param rhs_format: Format of the rhs tensor. Available formats: {e2m1, e4m3, e5m2, bf16}.
:type rhs_format: str
:param acc: The accumulator tensor. If not None, the result is added to this tensor.
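
The packing and scaling conventions above can be illustrated with a small host-side reference in plain Python. This is a sketch for checking inputs, not Triton code and not how the kernel is implemented. The helper names (`decode_e2m1`, `unpack_fp4`, `decode_e8m0`, `dequantize_row`, `scaled_dot_1d`) are hypothetical. The e2m1 value table, the e8m0 bias of 127, and the default group size of 32 elements per scale are taken from the OCP MX specification linked above and are assumptions about how `dot_scaled` interprets its inputs.

```python
# Magnitudes of the eight non-negative e2m1 (fp4) codes, indexed by the low
# 3 bits of the nibble. Bit 3 of the nibble is the sign bit.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble):
    """Decode one 4-bit e2m1 code into a Python float."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def unpack_fp4(packed):
    """Unpack uint8 bytes holding two fp4 values each, first element in the low bits."""
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0xF))  # first element: low nibble
        values.append(decode_e2m1(byte >> 4))   # second element: high nibble
    return values


def decode_e8m0(scale_byte):
    """Decode an e8m0 scale stored as uint8: 2**(byte - 127), with 0xFF as NaN."""
    if scale_byte == 0xFF:
        return float("nan")
    return 2.0 ** (scale_byte - 127)


def dequantize_row(packed, scales, group_size=32):
    """Apply one e8m0 scale to each group of `group_size` unpacked fp4 values."""
    values = unpack_fp4(packed)
    return [v * decode_e8m0(scales[i // group_size]) for i, v in enumerate(values)]


def scaled_dot_1d(lhs_packed, lhs_scales, rhs_packed, rhs_scales, group_size=32):
    """Reference dot product of two scaled fp4 vectors (one row times one column)."""
    a = dequantize_row(lhs_packed, lhs_scales, group_size)
    b = dequantize_row(rhs_packed, rhs_scales, group_size)
    return sum(x * y for x, y in zip(a, b))
```

For example, the byte `0x21` unpacks to the pair `(0.5, 1.0)`, because the low nibble `0x1` comes first. A reference like this can be useful for validating packed inputs and scales on the host before passing them to a kernel that calls `dot_scaled`.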