triton.experimental.gluon.language.nvidia.hopper.warpgroup_mma
- triton.experimental.gluon.language.nvidia.hopper.warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_acc=None, is_async=False, _semantic=None)
Perform a warpgroup MMA (Tensor Core) operation: acc = a * b + (acc if use_acc else 0), where * denotes matrix multiplication.
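The formula can be read as the following pure-Python reference model. It is only a sketch of the arithmetic. The name warpgroup_mma_reference is hypothetical, matrices are plain lists of lists, and the model ignores everything that makes the real operation a Tensor Core instruction: operand placement (register tensor or shared_memory_descriptor), layouts, dtypes, precision and asynchrony.

```python
Matrix = list[list[float]]


def warpgroup_mma_reference(a: Matrix, b: Matrix, acc: Matrix, use_acc: bool = True) -> Matrix:
    """Hypothetical reference model: return a @ b + (acc if use_acc else 0).

    a is M x K, b is K x N, acc is M x N. `acc` itself is not modified.
    """
    m, k = len(a), len(b)
    n = len(b[0])
    if any(len(row) != k for row in a):
        raise ValueError("inner dimensions of a and b must match")
    if len(acc) != m or any(len(row) != n for row in acc):
        raise ValueError("acc must be M x N")
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            # Start from the incoming accumulator value, or from zero when use_acc=False.
            s = acc[i][j] if use_acc else 0.0
            for p in range(k):
                s += a[i][p] * b[p][j]
            row.append(s)
        out.append(row)
    return out
```

For example, with a = [[1, 2], [3, 4]], b = [[5, 6], [7, 8]] and an all-ones acc, use_acc=True gives [[20, 23], [44, 51]], while use_acc=False gives the plain product [[19, 22], [43, 50]].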
- Parameters:
a (tensor or shared_memory_descriptor) – Left hand side operand.
b (shared_memory_descriptor) – Right hand side operand.
acc (tensor) – Accumulator tensor.
use_acc (bool) – Whether to use the initial value of the accumulator. Defaults to True.
precision (str, optional) – Dot input precision. Defaults to None, in which case the builder's default precision is used.
max_num_imprecise_acc (int, optional) – Maximum number of imprecise accumulations. Used for fp8 -> fp32 dots; determines how many accumulations are performed in limited precision. Defaults to None, which means no upcasting is done.
is_async (bool) – Whether the operation is asynchronous. Defaults to False.
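A conceptual sketch of what max_num_imprecise_acc bounds, under loudly stated assumptions. Python floats stand in for the fp32 accumulator. A hypothetical truncate_mantissa helper stands in for the limited-precision accumulation; the real hardware precision is not specified here, and mantissa_bits=10 is an arbitrary illustrative choice. "Upcasting" is modeled as splitting K into chunks of max_num_imprecise_acc products, accumulating each chunk in limited precision, and adding each partial sum into a full-precision total. This illustrates the idea only and is not Triton's code generation.

```python
import struct


def truncate_mantissa(x: float, bits: int) -> float:
    """Hypothetical stand-in for a limited-precision accumulator:
    keep only the top `bits` of the 52-bit float64 mantissa (round toward zero)."""
    (raw,) = struct.unpack(">Q", struct.pack(">d", x))
    drop = 52 - bits
    raw &= ~((1 << drop) - 1) & 0xFFFFFFFFFFFFFFFF
    (y,) = struct.unpack(">d", struct.pack(">Q", raw))
    return y


def dot_with_promotion(a_row, b_col, max_num_imprecise_acc=None, mantissa_bits=10):
    """Dot product of one row of `a` with one column of `b` (hypothetical model).

    At most `max_num_imprecise_acc` products are accumulated in limited precision
    before the partial sum is promoted into a full-precision total.
    None models "no upcasting": the whole K dimension stays in limited precision.
    """
    k = len(a_row)
    if k == 0:
        return 0.0
    chunk = k if max_num_imprecise_acc is None else max_num_imprecise_acc
    total = 0.0  # full-precision accumulator
    for start in range(0, k, chunk):
        partial = 0.0  # limited-precision accumulator, restarted for every chunk
        for i in range(start, min(start + chunk, k)):
            partial = truncate_mantissa(partial + a_row[i] * b_col[i], mantissa_bits)
        total += partial  # promotion step: full-precision add
    return total


a_row = [1.0] + [2.0 ** -12] * 1024
b_col = [1.0] * 1025
# Exact result: 1.0 + 1024 * 2**-12 == 1.25
print(dot_with_promotion(a_row, b_col))      # 1.0: every small term is lost
print(dot_with_promotion(a_row, b_col, 32))  # 1.242431640625: only the first chunk loses terms
```

In this model, smaller values of max_num_imprecise_acc trade extra full-precision additions for a result closer to the exact sum.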
- Returns:
The result tensor if synchronous; if asynchronous, a token from which the value can be loaded once it has been computed.
- Return type:
tensor or warpgroup_mma_accumulator