triton.experimental.gluon.language.map_elementwise

triton.experimental.gluon.language.map_elementwise(scalar_fn: Callable[[...], Tuple[tensor, ...]], *args: tensor, pack=1, _semantic=None, _generator=None)

Map a scalar function over a tensor.

The input tensors in args are implicitly broadcast to a common shape.
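As a rough mental model, the following CPU sketch mimics this broadcasting for 1-D Python lists and scalars: scalars and length-1 lists are repeated to the common length, and the scalar function is called once per position. `map_elementwise_ref` is a hypothetical helper, not part of Triton, and real Triton broadcasting operates on multi-dimensional tensor shapes.

```python
def map_elementwise_ref(scalar_fn, *args):
    # Hypothetical CPU reference model, not a Triton API (1-D only).
    # Collect the lengths of list arguments that are not length-1 singletons.
    lengths = {len(a) for a in args if isinstance(a, list) and len(a) != 1}
    if len(lengths) > 1:
        raise ValueError(f"cannot broadcast lengths {sorted(lengths)}")
    n = lengths.pop() if lengths else 1

    def expand(a):
        # Scalars and length-1 lists are repeated to the common length n.
        if isinstance(a, list):
            return a if len(a) == n else a * n
        return [a] * n

    columns = [expand(a) for a in args]
    # Call scalar_fn once per broadcast position.
    return [scalar_fn(*row) for row in zip(*columns)]
```

For example, `map_elementwise_ref(lambda a, b: a + b, [1, 2, 3], 10)` returns `[11, 12, 13]`, and `map_elementwise_ref(max, [1, 5, 3], [4])` returns `[4, 5, 4]`.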

This is useful for expressing control flow on individual elements of a tensor, for example a multi-branch function where one branch is more expensive than the other. With tl.where, both sides of the branch are computed for every element and one result is then selected; with an if inside the mapped function, only the branch that is taken is executed.
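The cost difference can be illustrated with a plain-Python analogy that counts how often the expensive branch runs. This is only a sketch under the assumption that each element is handled independently; it ignores GPU-specific effects such as divergent branches within a warp, and `where_style`, `branch_style`, and `expensive_branch` are hypothetical names.

```python
import math

calls = {"expensive": 0}

def expensive_branch(x, alpha):
    # Stands in for a costly computation such as alpha * (exp(x) - 1).
    calls["expensive"] += 1
    return alpha * (math.exp(x) - 1)

def where_style(xs, alpha):
    # Like tl.where: the expensive value is computed for every element,
    # then one of the two candidates is selected per element.
    negative = [expensive_branch(x, alpha) for x in xs]
    return [x if x > 0 else n for x, n in zip(xs, negative)]

def branch_style(xs, alpha):
    # Like an `if` inside a mapped scalar function: the expensive branch
    # runs only for the elements that actually take it.
    return [x if x > 0 else expensive_branch(x, alpha) for x in xs]
```

With `xs = [1.0, 2.0, -1.0, 3.0]`, both functions return the same values, but `where_style` evaluates the expensive branch four times while `branch_style` evaluates it once.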

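import triton
import triton.language as tl

@triton.jit
def selu_scalar(x, alpha):
    # Called per element: only the branch that is taken is executed.
    # (Standard SELU also multiplies both branches by a fixed scale, about 1.0507.)
    if x > 0:
        return x
    else:
        return alpha * (tl.exp(x) - 1)

@triton.jit
def selu(x, alpha):
    # Apply selu_scalar to every element of x; alpha is broadcast to x's shape.
    return tl.map_elementwise(selu_scalar, x, alpha)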
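
Parameters:
  • scalar_fn – the function applied to the elements of args; it may return one value or a tuple of values.

  • args – the input tensors, implicitly broadcast to a common shape.

  • pack – the number of elements processed by a single call of scalar_fn (default 1).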

Returns:

A single tensor if scalar_fn returns one value, or a tuple of tensors if it returns several.
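The relationship between per-element results and the returned tensors can be modeled on the CPU: when the scalar function returns a tuple for each element, the per-element tuples are regrouped into one output per tuple position. This is a hypothetical sketch (`map_elementwise_multi_ref` and `magnitude_and_sign` are illustrative names, not Triton APIs); in a kernel, the analogous call would unpack several tensors from a single map_elementwise call.

```python
def map_elementwise_multi_ref(scalar_fn, xs):
    # Hypothetical CPU model, not a Triton API.
    outputs = [scalar_fn(x) for x in xs]
    if outputs and isinstance(outputs[0], tuple):
        # Transpose per-element tuples into one list per output.
        return tuple(list(col) for col in zip(*outputs))
    # A single result per element yields a single output.
    return outputs

def magnitude_and_sign(x):
    # A scalar function with two results.
    return abs(x), (-1 if x < 0 else 1)
```

Here `map_elementwise_multi_ref(magnitude_and_sign, [-2, 3])` returns `([2, 3], [-1, 1])`, while `map_elementwise_multi_ref(abs, [-2, 3])` returns the single list `[2, 3]`.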